Skip to content

Commit ab4061e

Browse files
add weight-only int8 QAT scheme and update tests for torchao 0.15.0 (#3859)
* add int8 weight-only QAT scheme, add test, fix tests for current torchao version
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* change quantization to PerAxis
* lambda =/
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* add torchao messages, remove group_size from int8
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* raise exception on missing torchao
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* touch up the torchao imports
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ca0ecf1 commit ab4061e

File tree

2 files changed

+87
-35
lines changed

2 files changed

+87
-35
lines changed

tests/utils/test_qat.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@
44

55
import pytest
66
import torch
7-
from torchao.quantization.qat import FakeQuantizedLinear
8-
from torchao.quantization.qat.fake_quantizer import (
9-
FakeQuantizerBase,
10-
Float8FakeQuantizer,
11-
Int4WeightPreshuffledFakeQuantizer,
12-
)
7+
8+
try:
9+
from torchao.quantization.qat import FakeQuantizedLinear
10+
from torchao.quantization.qat.fake_quantizer import (
11+
FakeQuantizerBase,
12+
Float8FakeQuantizer,
13+
Int4WeightFakeQuantizer,
14+
IntxFakeQuantizer,
15+
)
16+
except ImportError:
17+
print(
18+
"Missing torchao import, please install or upgrade torchao with: pip install 'torchao>=0.15.0'"
19+
)
1320

1421

1522
class _CountingFakeQuantizer(torch.nn.Module):
@@ -49,22 +56,29 @@ def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
4956
"""
5057
Verify that the given linear contains fake quantizers according to the `qat_scheme`.
5158
"""
59+
weight_only = False
5260
if qat_scheme == "fp8-int4":
5361
act_fq_class = Float8FakeQuantizer
54-
weight_fq_class = Int4WeightPreshuffledFakeQuantizer
62+
weight_fq_class = Int4WeightFakeQuantizer
5563
min_in_features = 128
5664
elif qat_scheme == "fp8-fp8":
5765
act_fq_class = Float8FakeQuantizer
5866
weight_fq_class = Float8FakeQuantizer
5967
min_in_features = -1
68+
elif qat_scheme == "int8":
69+
act_fq_class = None
70+
weight_fq_class = IntxFakeQuantizer
71+
min_in_features = 128
72+
weight_only = True
6073
else:
6174
raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
6275

6376
# Check base layer activations and weights
6477
base_layer = getattr(linear, "base_layer", linear)
6578
if base_layer.in_features >= min_in_features:
6679
assert isinstance(base_layer, FakeQuantizedLinear)
67-
assert isinstance(base_layer.activation_fake_quantizer, act_fq_class)
80+
if not weight_only:
81+
assert isinstance(base_layer.activation_fake_quantizer, act_fq_class)
6882
assert isinstance(base_layer.weight_fake_quantizer, weight_fq_class)
6983

7084
# Check lora A and B (only for full_finetuning=False)
@@ -73,22 +87,26 @@ def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
7387
lora_B = linear.lora_B.default
7488
if lora_A.in_features >= min_in_features:
7589
assert isinstance(lora_A, FakeQuantizedLinear)
76-
assert isinstance(lora_A.activation_fake_quantizer, act_fq_class)
90+
if not weight_only:
91+
assert isinstance(lora_A.activation_fake_quantizer, act_fq_class)
7792
assert isinstance(lora_A.weight_fake_quantizer, weight_fq_class)
7893
if lora_B.in_features >= min_in_features:
7994
assert isinstance(lora_B, FakeQuantizedLinear)
80-
assert isinstance(lora_B.activation_fake_quantizer, act_fq_class)
95+
if not weight_only:
96+
assert isinstance(lora_B.activation_fake_quantizer, act_fq_class)
8197
assert isinstance(lora_B.weight_fake_quantizer, weight_fq_class)
8298

8399

84100
def _test_fake_quantizers_are_called(
85101
model: torch.nn.Module,
86102
example_inputs: Dict,
87103
full_finetuning: bool,
104+
qat_scheme: str,
88105
):
89106
"""
90107
Verify that the fake quantizers are actually called when the model is called.
91108
"""
109+
weight_only = qat_scheme == "int8"
92110

93111
def _swap_fake_quantizers(model: torch.nn.Module):
94112
for name, child in model.named_children():
@@ -99,20 +117,23 @@ def _assert_fake_quantizers_are_called(model: torch.nn.Module):
99117
for name, child in model.named_children():
100118
if full_finetuning:
101119
if isinstance(child, FakeQuantizedLinear):
102-
assert child.activation_fake_quantizer.count == 1
120+
if not weight_only:
121+
assert child.activation_fake_quantizer.count == 1
103122
assert child.weight_fake_quantizer.count == 1
104123
else:
105124
# For LoRA, we only fake quantize the input activations once per block:
106125
# For self_attn, we only fake quantize the q_proj's input activations
107126
# For mlp, we only fake quantize the gate_proj's input activations
108127
if name == "self_attn":
109128
base_layer = child.q_proj.base_layer
110-
assert hasattr(base_layer, "activation_fake_quantizer")
111-
assert base_layer.activation_fake_quantizer.count == 1
129+
if not weight_only:
130+
assert hasattr(base_layer, "activation_fake_quantizer")
131+
assert base_layer.activation_fake_quantizer.count == 1
112132
elif name == "mlp":
113133
base_layer = child.gate_proj.base_layer
114-
assert hasattr(base_layer, "activation_fake_quantizer")
115-
assert base_layer.activation_fake_quantizer.count == 1
134+
if not weight_only:
135+
assert hasattr(base_layer, "activation_fake_quantizer")
136+
assert base_layer.activation_fake_quantizer.count == 1
116137
elif isinstance(child, FakeQuantizedLinear):
117138
# Weight fake quantizers should always be called
118139
assert child.weight_fake_quantizer.count == 1
@@ -124,7 +145,7 @@ def _assert_fake_quantizers_are_called(model: torch.nn.Module):
124145
model.apply(_assert_fake_quantizers_are_called)
125146

126147

127-
def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
148+
def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
128149
"""
129150
Test that all linear layers in the model are fake quantized according to the `qat_scheme`.
130151
"""
@@ -141,16 +162,16 @@ def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
141162
_test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
142163
_test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
143164
inputs = tokenizer("How are you?", return_tensors = "pt")
144-
_test_fake_quantizers_are_called(model, inputs, full_finetuning)
165+
_test_fake_quantizers_are_called(model, inputs, full_finetuning, qat_scheme)
145166

146167

147168
# TODO: there are bad interactions across tests right now, need to figure out
148169
# how to disable model caching before re-enabling this test
149-
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
150-
def _test_full_model_fake_quantize(qat_scheme: bool):
170+
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"])
171+
def _test_full_model_fake_quantize(qat_scheme: str):
151172
_test_model_fake_quantize(qat_scheme, full_finetuning = True)
152173

153174

154-
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
155-
def test_lora_model_fake_quantize(qat_scheme: bool):
175+
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"])
176+
def test_lora_model_fake_quantize(qat_scheme: str):
156177
_test_model_fake_quantize(qat_scheme, full_finetuning = False)

unsloth/models/_utils.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@
175175
# Stop "Special tokens have been added in the vocabulary, ..."
176176
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL + 1)
177177

178+
TORCHAO_MSG = "Error: torchao not found, please install with `pip install torchao`"
179+
178180

179181
# Ignore logging messages
180182
class HideLoggingMessage(logging.Filter):
@@ -2211,9 +2213,12 @@ def _prepare_model_for_qat(
22112213
QAT can be optionally combined with LoRA fine-tuning to for additional throughput improvement.
22122214
For more details: https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700
22132215
"""
2214-
from torchao.quantization import PerRow, quantize_
2215-
from torchao.quantization.granularity import PerGroup, PerAxis
2216-
from torchao.quantization.qat import QATConfig
2216+
try:
2217+
from torchao.quantization import PerRow, quantize_
2218+
from torchao.quantization.granularity import PerGroup, PerAxis
2219+
from torchao.quantization.qat import QATConfig
2220+
except ImportError:
2221+
raise ImportError(TORCHAO_MSG)
22172222

22182223
# Gemma3 models have issues with int8 embedding quantization due to their
22192224
# large vocabulary size (262144). Auto-switch to int4 weight-only instead.
@@ -2230,8 +2235,10 @@ def _prepare_model_for_qat(
22302235
if not isinstance(qat_scheme, TorchAOConfig):
22312236
torchao_config: Optional[TorchAOConfig] = None
22322237
if qat_scheme == "fp8-int4":
2233-
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
2234-
2238+
try:
2239+
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
2240+
except ImportError:
2241+
raise ImportError(TORCHAO_MSG)
22352242
group_size = 128
22362243
base_config = Float8DynamicActivationInt4WeightConfig()
22372244
filter_fn = (
@@ -2243,20 +2250,26 @@ def _prepare_model_for_qat(
22432250
base_config_and_filter_fns = [(base_config, filter_fn)],
22442251
)
22452252
elif qat_scheme == "fp8-fp8":
2246-
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
2247-
2253+
try:
2254+
from torchao.quantization import (
2255+
Float8DynamicActivationFloat8WeightConfig,
2256+
)
2257+
except ImportError:
2258+
raise ImportError(TORCHAO_MSG)
22482259
base_config = Float8DynamicActivationFloat8WeightConfig(
22492260
granularity = PerRow()
22502261
)
22512262
torchao_config = TorchAOConfig(
22522263
qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
22532264
)
22542265
elif qat_scheme == "int8-int4":
2255-
from torchao.quantization import (
2256-
Int8DynamicActivationIntxWeightConfig,
2257-
IntxWeightOnlyConfig,
2258-
)
2259-
2266+
try:
2267+
from torchao.quantization import (
2268+
Int8DynamicActivationIntxWeightConfig,
2269+
IntxWeightOnlyConfig,
2270+
)
2271+
except ImportError:
2272+
raise ImportError(TORCHAO_MSG)
22602273
torchao_config = TorchAOConfig(
22612274
qat_scheme = qat_scheme,
22622275
base_config_and_filter_fns = [
@@ -2276,8 +2289,10 @@ def _prepare_model_for_qat(
22762289
prequantization_transform = _untie_input_output_embeddings,
22772290
)
22782291
elif qat_scheme == "int4":
2279-
from torchao.quantization import Int4WeightOnlyConfig
2280-
2292+
try:
2293+
from torchao.quantization import Int4WeightOnlyConfig
2294+
except ImportError:
2295+
raise ImportError(TORCHAO_MSG)
22812296
group_size = 128
22822297
base_config = Int4WeightOnlyConfig(group_size = group_size)
22832298
filter_fn = (
@@ -2288,6 +2303,22 @@ def _prepare_model_for_qat(
22882303
qat_scheme = qat_scheme,
22892304
base_config_and_filter_fns = [(base_config, filter_fn)],
22902305
)
2306+
elif qat_scheme == "int8":
2307+
try:
2308+
from torchao.quantization import IntxWeightOnlyConfig
2309+
from torchao.quantization.granularity import PerAxis
2310+
except ImportError:
2311+
raise ImportError(TORCHAO_MSG)
2312+
2313+
base_config = IntxWeightOnlyConfig(
2314+
weight_dtype = torch.int8,
2315+
granularity = PerAxis(0),
2316+
)
2317+
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear)
2318+
torchao_config = TorchAOConfig(
2319+
qat_scheme = qat_scheme,
2320+
base_config_and_filter_fns = [(base_config, filter_fn)],
2321+
)
22912322
else:
22922323
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
22932324
assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"

0 commit comments

Comments (0)