44
55import pytest
66import torch
7- from torchao .quantization .qat import FakeQuantizedLinear
8- from torchao .quantization .qat .fake_quantizer import (
9- FakeQuantizerBase ,
10- Float8FakeQuantizer ,
11- Int4WeightPreshuffledFakeQuantizer ,
12- )
7+
8+ try :
9+ from torchao .quantization .qat import FakeQuantizedLinear
10+ from torchao .quantization .qat .fake_quantizer import (
11+ FakeQuantizerBase ,
12+ Float8FakeQuantizer ,
13+ Int4WeightFakeQuantizer ,
14+ IntxFakeQuantizer ,
15+ )
16+ except ImportError :
17+ print (
18+ "Missing torchao import, please install or upgrade torchao with: pip install 'torchao>=0.15.0'"
19+ )
1320
1421
1522class _CountingFakeQuantizer (torch .nn .Module ):
@@ -49,22 +56,29 @@ def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
4956 """
5057 Verify that the given linear contains fake quantizers according to the `qat_scheme`.
5158 """
59+ weight_only = False
5260 if qat_scheme == "fp8-int4" :
5361 act_fq_class = Float8FakeQuantizer
54- weight_fq_class = Int4WeightPreshuffledFakeQuantizer
62+ weight_fq_class = Int4WeightFakeQuantizer
5563 min_in_features = 128
5664 elif qat_scheme == "fp8-fp8" :
5765 act_fq_class = Float8FakeQuantizer
5866 weight_fq_class = Float8FakeQuantizer
5967 min_in_features = - 1
68+ elif qat_scheme == "int8" :
69+ act_fq_class = None
70+ weight_fq_class = IntxFakeQuantizer
71+ min_in_features = 128
72+ weight_only = True
6073 else :
6174 raise ValueError (f"Unknown qat_scheme: { qat_scheme } " )
6275
6376 # Check base layer activations and weights
6477 base_layer = getattr (linear , "base_layer" , linear )
6578 if base_layer .in_features >= min_in_features :
6679 assert isinstance (base_layer , FakeQuantizedLinear )
67- assert isinstance (base_layer .activation_fake_quantizer , act_fq_class )
80+ if not weight_only :
81+ assert isinstance (base_layer .activation_fake_quantizer , act_fq_class )
6882 assert isinstance (base_layer .weight_fake_quantizer , weight_fq_class )
6983
7084 # Check lora A and B (only for full_finetuning=False)
@@ -73,22 +87,26 @@ def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
7387 lora_B = linear .lora_B .default
7488 if lora_A .in_features >= min_in_features :
7589 assert isinstance (lora_A , FakeQuantizedLinear )
76- assert isinstance (lora_A .activation_fake_quantizer , act_fq_class )
90+ if not weight_only :
91+ assert isinstance (lora_A .activation_fake_quantizer , act_fq_class )
7792 assert isinstance (lora_A .weight_fake_quantizer , weight_fq_class )
7893 if lora_B .in_features >= min_in_features :
7994 assert isinstance (lora_B , FakeQuantizedLinear )
80- assert isinstance (lora_B .activation_fake_quantizer , act_fq_class )
95+ if not weight_only :
96+ assert isinstance (lora_B .activation_fake_quantizer , act_fq_class )
8197 assert isinstance (lora_B .weight_fake_quantizer , weight_fq_class )
8298
8399
84100def _test_fake_quantizers_are_called (
85101 model : torch .nn .Module ,
86102 example_inputs : Dict ,
87103 full_finetuning : bool ,
104+ qat_scheme : str ,
88105):
89106 """
90107 Verify that the fake quantizers are actually called when the model is called.
91108 """
109+ weight_only = qat_scheme == "int8"
92110
93111 def _swap_fake_quantizers (model : torch .nn .Module ):
94112 for name , child in model .named_children ():
@@ -99,20 +117,23 @@ def _assert_fake_quantizers_are_called(model: torch.nn.Module):
99117 for name , child in model .named_children ():
100118 if full_finetuning :
101119 if isinstance (child , FakeQuantizedLinear ):
102- assert child .activation_fake_quantizer .count == 1
120+ if not weight_only :
121+ assert child .activation_fake_quantizer .count == 1
103122 assert child .weight_fake_quantizer .count == 1
104123 else :
105124 # For LoRA, we only fake quantize the input activations once per block:
106125 # For self_attn, we only fake quantize the q_proj's input activations
107126 # For mlp, we only fake quantize the gate_proj's input activations
108127 if name == "self_attn" :
109128 base_layer = child .q_proj .base_layer
110- assert hasattr (base_layer , "activation_fake_quantizer" )
111- assert base_layer .activation_fake_quantizer .count == 1
129+ if not weight_only :
130+ assert hasattr (base_layer , "activation_fake_quantizer" )
131+ assert base_layer .activation_fake_quantizer .count == 1
112132 elif name == "mlp" :
113133 base_layer = child .gate_proj .base_layer
114- assert hasattr (base_layer , "activation_fake_quantizer" )
115- assert base_layer .activation_fake_quantizer .count == 1
134+ if not weight_only :
135+ assert hasattr (base_layer , "activation_fake_quantizer" )
136+ assert base_layer .activation_fake_quantizer .count == 1
116137 elif isinstance (child , FakeQuantizedLinear ):
117138 # Weight fake quantizers should always be called
118139 assert child .weight_fake_quantizer .count == 1
@@ -124,7 +145,7 @@ def _assert_fake_quantizers_are_called(model: torch.nn.Module):
124145 model .apply (_assert_fake_quantizers_are_called )
125146
126147
127- def _test_model_fake_quantize (qat_scheme : bool , full_finetuning : bool ):
148+ def _test_model_fake_quantize (qat_scheme : str , full_finetuning : bool ):
128149 """
129150 Test that all linear layers in the model are fake quantized according to the `qat_scheme`.
130151 """
@@ -141,16 +162,16 @@ def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
141162 _test_linear_is_fake_quantized (layer .mlp .up_proj , qat_scheme )
142163 _test_linear_is_fake_quantized (layer .mlp .down_proj , qat_scheme )
143164 inputs = tokenizer ("How are you?" , return_tensors = "pt" )
144- _test_fake_quantizers_are_called (model , inputs , full_finetuning )
165+ _test_fake_quantizers_are_called (model , inputs , full_finetuning , qat_scheme )
145166
146167
# TODO: there are bad interactions across tests right now, need to figure out
# how to disable model caching before re-enabling this test
@pytest.mark.parametrize("qat_scheme", ("fp8-int4", "fp8-fp8", "int8"))
def _test_full_model_fake_quantize(qat_scheme: str):
    """Verify every linear in the model is fake quantized under full finetuning.

    Named with a leading underscore (instead of ``test_``) so pytest skips it
    until the cross-test model-caching interaction noted above is resolved.
    """
    _test_model_fake_quantize(qat_scheme=qat_scheme, full_finetuning=True)
152173
153174
@pytest.mark.parametrize("qat_scheme", ("fp8-int4", "fp8-fp8", "int8"))
def test_lora_model_fake_quantize(qat_scheme: str):
    """Verify every linear in the LoRA model is fake quantized for `qat_scheme`."""
    _test_model_fake_quantize(qat_scheme=qat_scheme, full_finetuning=False)
0 commit comments