Commit 65fdd21

Feat (axe): extended unit testing
1 parent 89b3369 commit 65fdd21

2 files changed: +61 -16 lines changed

2 files changed

+61
-16
lines changed

tests/brevitas/graph/equalization_fixtures.py

+3 -2

@@ -411,14 +411,15 @@ def forward(self, x):
 
 
 @pytest_cases.fixture
-def quant_convdepthconv_model():
+def quant_convdepthconv_model(input_quant, weight_quant):
 
     class QuantConvDepthConvModel(nn.Module):
 
         def __init__(self) -> None:
             super().__init__()
             self.conv = qnn.QuantConv2d(3, 16, kernel_size=3)
-            self.conv_0 = qnn.QuantConv2d(16, 16, kernel_size=1, groups=16)
+            self.conv_0 = qnn.QuantConv2d(16, 16, kernel_size=1, groups=16,
+                input_quant=input_quant, weight_quant=weight_quant)
             self.relu = qnn.QuantReLU(return_quant_tensor=True)
 
         def forward(self, x):
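For context, here is a minimal sketch (not part of the commit) of the model the parametrized fixture now builds. The Int8ActPerTensorFloat / Int8WeightPerTensorFloat quantizers and the abbreviated forward pass are assumed stand-ins for whatever the fixture's input_quant / weight_quant parameters and the unchanged remainder of the class actually supply:

# Sketch only: mirrors the updated fixture with concrete quantizers.
# Int8ActPerTensorFloat / Int8WeightPerTensorFloat are assumed stand-ins for the
# fixture's input_quant / weight_quant parameters; the forward pass is abbreviated.
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat


class QuantConvDepthConvModel(nn.Module):

    def __init__(self, input_quant, weight_quant) -> None:
        super().__init__()
        self.conv = qnn.QuantConv2d(3, 16, kernel_size=3)
        # The depthwise layer is the one the commit extends with explicit quantizers.
        self.conv_0 = qnn.QuantConv2d(
            16, 16, kernel_size=1, groups=16, input_quant=input_quant, weight_quant=weight_quant)
        self.relu = qnn.QuantReLU(return_quant_tensor=True)

    def forward(self, x):
        return self.relu(self.conv_0(self.conv(x)))


model = QuantConvDepthConvModel(Int8ActPerTensorFloat, Int8WeightPerTensorFloat)
out = model(torch.randn(1, 3, 8, 8))  # forward pass collects scaling factors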

tests/brevitas/graph/test_gpxq.py

+58 -14
@@ -9,18 +9,28 @@
 
 from brevitas.graph.gpfq import gpfq_mode
 from brevitas.graph.gptq import gptq_mode
+from brevitas_examples.imagenet_classification.ptq.ptq_common import _a2q_layer_filter_fnc
 
-from .equalization_fixtures import *
 
+from .equalization_fixtures import *
 
 def apply_gpfq(
-        calib_loader: DataLoader, model: nn.Module, act_order: bool, use_quant_activations: bool):
+        calib_loader: DataLoader,
+        model: nn.Module,
+        act_order: bool,
+        use_quant_activations: bool,
+        max_accumulator_bit_width: int,
+        max_accumulator_tile_size: int):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        with gpfq_mode(model, use_quant_activations=use_quant_activations,
-                       act_order=act_order) as gpfq:
+        with gpfq_mode(model,
+                       act_order=act_order,
+                       a2q_layer_filter_fnc=_a2q_layer_filter_fnc,
+                       use_quant_activations=use_quant_activations,
+                       max_accumulator_tile_size=max_accumulator_tile_size,
+                       max_accumulator_bit_width=max_accumulator_bit_width) as gpfq:
             gpfq_model = gpfq.model
             for _ in range(gpfq.num_layers):
                 for _, (images, _) in enumerate(calib_loader):
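The same calibration pattern is applied to gptq_mode in the next hunk. As a reference, here is a minimal sketch (not part of the commit) of what the updated helper wraps, reusing the QuantConvDepthConvModel built in the sketch above; the 32/32 accumulator values come from the test parametrization below, and the trailing gpfq.update() call is assumed from the existing helpers since the hunk truncates the loop body:

# Sketch only, assuming `model` is the quantized model built in the earlier sketch
# (and has already seen a forward pass). gpfq.update() is assumed to close the
# per-layer loop, which the hunk above truncates.
import torch
from torch.utils.data import DataLoader, TensorDataset
from brevitas.graph.gpfq import gpfq_mode
from brevitas_examples.imagenet_classification.ptq.ptq_common import _a2q_layer_filter_fnc

inp = torch.randn(32, 3, 8, 8)
calib_loader = DataLoader(TensorDataset(inp, inp), batch_size=16)

model.eval()
with torch.no_grad():
    with gpfq_mode(model,
                   act_order=False,
                   a2q_layer_filter_fnc=_a2q_layer_filter_fnc,
                   use_quant_activations=True,
                   max_accumulator_tile_size=32,
                   max_accumulator_bit_width=32) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for images, _ in calib_loader:
                gpfq_model(images)
            gpfq.update()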
@@ -31,13 +41,22 @@ def apply_gpfq(
 
 
 def apply_gptq(
-        calib_loader: DataLoader, model: nn.Module, act_order: bool, use_quant_activations: bool):
+        calib_loader: DataLoader,
+        model: nn.Module,
+        act_order: bool,
+        use_quant_activations: bool,
+        max_accumulator_bit_width: int,
+        max_accumulator_tile_size: int):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        with gptq_mode(model, use_quant_activations=use_quant_activations,
-                       act_order=act_order) as gptq:
+        with gptq_mode(model,
+                       act_order=act_order,
+                       a2q_layer_filter_fnc=_a2q_layer_filter_fnc,
+                       use_quant_activations=use_quant_activations,
+                       max_accumulator_bit_width=max_accumulator_bit_width,
+                       max_accumulator_tile_size=max_accumulator_tile_size) as gptq:
             gptq_model = gptq.model
             for _ in range(gptq.num_layers):
                 for _, (images, _) in enumerate(calib_loader):
@@ -54,12 +73,21 @@ def apply_gptq(
 @pytest.mark.parametrize("use_quant_activations", [True, False])
 @pytest.mark.parametrize(
     "apply_gpxq_tuple", apply_gpxq_func_map.items(), ids=apply_gpxq_func_map.keys())
-def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq_tuple, request):
+@pytest.mark.parametrize("max_accumulator_bit_width", [None, 12, 32])
+@pytest.mark.parametrize("max_accumulator_tile_size", [None, 32])
+def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq_tuple, max_accumulator_bit_width, max_accumulator_tile_size, request):
 
     test_id = request.node.callspec.id
+    input_quant = test_id.split('-')[1]
 
     torch.manual_seed(SEED)
 
+    if (max_accumulator_bit_width is None) and (max_accumulator_tile_size is not None):
+        pytest.skip("max_accumulator_tile_size doesn't matter if max_accumulator_bit_width is None.")
+
+    if (max_accumulator_bit_width is not None) and input_quant.startswith("MXFloat"):
+        pytest.skip("AXE does not currently support minifloat formats.")
+
     name, apply_gpxq = apply_gpxq_tuple
 
     model_class = toy_quant_model
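For orientation, a small illustrative enumeration (not part of the commit) of how the first skip rule above prunes the new (max_accumulator_bit_width, max_accumulator_tile_size) grid; the MXFloat rule additionally removes the bit-width-capped cases for minifloat input quantizers:

# Illustrative only: which accumulator parametrizations actually execute per the first skip rule.
from itertools import product

for bit_width, tile_size in product([None, 12, 32], [None, 32]):
    if bit_width is None and tile_size is not None:
        status = "skipped (tile size is irrelevant without a bit-width cap)"
    else:
        status = "runs"
    print(f"max_accumulator_bit_width={bit_width}, max_accumulator_tile_size={tile_size}: {status}")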
@@ -72,9 +100,25 @@ def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq
     model(inp)  # test forward pass and collect scaling factors
     dataset = TensorDataset(inp, inp)
     calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True)
-
-    apply_gpxq(
-        calib_loader=calib_loader,
-        model=model,
-        act_order=act_order,
-        use_quant_activations=use_quant_activations)
+
+    if (max_accumulator_bit_width is not None) and (input_quant == 'None' or not use_quant_activations):
+        # AXE (or A2GPxQ) requires that the quant activations are used. A2GPxQ.single_layer_update
+        # will raise a ValueError if AXE.quant_metadata is None (also see GPxQ.process_input). This
+        # will happen when `use_quant_activations=False` or when the input to a model is not quantized
+        # and `a2q_layer_filter_fnc` does not properly handle it.
+        with pytest.raises(ValueError):
+            apply_gpxq(
+                calib_loader=calib_loader,
+                model=model,
+                act_order=act_order,
+                use_quant_activations=use_quant_activations,
+                max_accumulator_bit_width=max_accumulator_bit_width,
+                max_accumulator_tile_size=max_accumulator_tile_size)
+    else:
+        apply_gpxq(
+            calib_loader=calib_loader,
+            model=model,
+            act_order=act_order,
+            use_quant_activations=use_quant_activations,
+            max_accumulator_bit_width=max_accumulator_bit_width,
+            max_accumulator_tile_size=max_accumulator_tile_size)
