@@ -9,18 +9,28 @@
 
 from brevitas.graph.gpfq import gpfq_mode
 from brevitas.graph.gptq import gptq_mode
+from brevitas_examples.imagenet_classification.ptq.ptq_common import _a2q_layer_filter_fnc
 
-from .equalization_fixtures import *
 
+from .equalization_fixtures import *
 
 def apply_gpfq(
-        calib_loader: DataLoader, model: nn.Module, act_order: bool, use_quant_activations: bool):
+        calib_loader: DataLoader,
+        model: nn.Module,
+        act_order: bool,
+        use_quant_activations: bool,
+        max_accumulator_bit_width: int,
+        max_accumulator_tile_size: int):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        with gpfq_mode(model, use_quant_activations=use_quant_activations,
-                       act_order=act_order) as gpfq:
+        with gpfq_mode(model,
+                       act_order=act_order,
+                       a2q_layer_filter_fnc=_a2q_layer_filter_fnc,
+                       use_quant_activations=use_quant_activations,
+                       max_accumulator_tile_size=max_accumulator_tile_size,
+                       max_accumulator_bit_width=max_accumulator_bit_width) as gpfq:
             gpfq_model = gpfq.model
             for _ in range(gpfq.num_layers):
                 for _, (images, _) in enumerate(calib_loader):
@@ -31,13 +41,22 @@ def apply_gpfq(
 
 
 def apply_gptq(
-        calib_loader: DataLoader, model: nn.Module, act_order: bool, use_quant_activations: bool):
+        calib_loader: DataLoader,
+        model: nn.Module,
+        act_order: bool,
+        use_quant_activations: bool,
+        max_accumulator_bit_width: int,
+        max_accumulator_tile_size: int):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        with gptq_mode(model, use_quant_activations=use_quant_activations,
-                       act_order=act_order) as gptq:
+        with gptq_mode(model,
+                       act_order=act_order,
+                       a2q_layer_filter_fnc=_a2q_layer_filter_fnc,
+                       use_quant_activations=use_quant_activations,
+                       max_accumulator_bit_width=max_accumulator_bit_width,
+                       max_accumulator_tile_size=max_accumulator_tile_size) as gptq:
             gptq_model = gptq.model
             for _ in range(gptq.num_layers):
                 for _, (images, _) in enumerate(calib_loader):
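Note for readers of this diff: the calibration loop bodies are elided between hunks, so the following sketch of the GPFQ call pattern under test is a reconstruction from the surrounding context rather than the file's verbatim contents. The forward-pass loop and the `gpfq.update()` step are assumptions, `model` and `calib_loader` stand in for the test fixtures, and the argument values are illustrative picks from the test's parameter grid.

import torch

from brevitas.graph.gpfq import gpfq_mode
from brevitas_examples.imagenet_classification.ptq.ptq_common import _a2q_layer_filter_fnc

with torch.no_grad():
    with gpfq_mode(model,
                   act_order=False,
                   a2q_layer_filter_fnc=_a2q_layer_filter_fnc,
                   use_quant_activations=True,
                   max_accumulator_tile_size=32,
                   max_accumulator_bit_width=12) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for _, (images, _) in enumerate(calib_loader):
                gpfq_model(images)  # assumed: quantized forward passes collect statistics
            gpfq.update()  # assumed: apply this layer's update before the next pass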
@@ -54,12 +73,21 @@ def apply_gptq(
 @pytest.mark.parametrize("use_quant_activations", [True, False])
 @pytest.mark.parametrize(
     "apply_gpxq_tuple", apply_gpxq_func_map.items(), ids=apply_gpxq_func_map.keys())
-def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq_tuple, request):
+@pytest.mark.parametrize("max_accumulator_bit_width", [None, 12, 32])
+@pytest.mark.parametrize("max_accumulator_tile_size", [None, 32])
+def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq_tuple, max_accumulator_bit_width, max_accumulator_tile_size, request):
 
     test_id = request.node.callspec.id
+    input_quant = test_id.split('-')[1]
 
     torch.manual_seed(SEED)
 
+    if (max_accumulator_bit_width is None) and (max_accumulator_tile_size is not None):
+        pytest.skip("max_accumulator_tile_size doesn't matter if max_accumulator_bit_width is None.")
+
+    if (max_accumulator_bit_width is not None) and input_quant.startswith("MXFloat"):
+        pytest.skip("AXE does not currently support minifloat formats.")
+
     name, apply_gpxq = apply_gpxq_tuple
 
     model_class = toy_quant_model
@@ -72,9 +100,25 @@ def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq
     model(inp)  # test forward pass and collect scaling factors
     dataset = TensorDataset(inp, inp)
     calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True)
-
-    apply_gpxq(
-        calib_loader=calib_loader,
-        model=model,
-        act_order=act_order,
-        use_quant_activations=use_quant_activations)
+
+    if (max_accumulator_bit_width is not None) and (input_quant == 'None' or not use_quant_activations):
+        # AXE (or A2GPxQ) requires that the quant activations are used. A2GPxQ.single_layer_update
+        # will raise a ValueError if AXE.quant_metadata is None (also see GPxQ.process_input). This
+        # will happen when `use_quant_activations=False` or when the input to a model is not quantized
+        # and `a2q_layer_filter_fnc` does not properly handle it.
+        with pytest.raises(ValueError):
+            apply_gpxq(
+                calib_loader=calib_loader,
+                model=model,
+                act_order=act_order,
+                use_quant_activations=use_quant_activations,
+                max_accumulator_bit_width=max_accumulator_bit_width,
+                max_accumulator_tile_size=max_accumulator_tile_size)
+    else:
+        apply_gpxq(
+            calib_loader=calib_loader,
+            model=model,
+            act_order=act_order,
+            use_quant_activations=use_quant_activations,
+            max_accumulator_bit_width=max_accumulator_bit_width,
+            max_accumulator_tile_size=max_accumulator_tile_size)
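Usage note: the guarded branch above encodes the accumulator-aware (AXE / A2GPxQ) precondition described in the inline comment. A minimal sketch of that failure mode in isolation, reusing this file's apply_gpfq helper; `model` and `calib_loader` are assumed fixtures and the numeric values are illustrative:

import pytest

# With an accumulator constraint set but activation quantization disabled,
# AXE has no quant metadata to bound the accumulator, so the run is expected
# to fail fast with a ValueError instead of silently producing a bad model.
with pytest.raises(ValueError):
    apply_gpfq(
        calib_loader=calib_loader,    # assumed: the same calibration DataLoader
        model=model,                  # assumed: a toy model without input quantization
        act_order=False,
        use_quant_activations=False,  # this is what removes the quant metadata
        max_accumulator_bit_width=12,
        max_accumulator_tile_size=32)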