
Commit 71f5abc

Update 20230306 (#403)
* Fix a bug in the equalization pass
* Add an interface to the LSQ pass that lets users pass in a custom optimizer
* Add a fuse matmul+add function
* Fix some typos
* Upload the QuantZoo dataset
1 parent e0298ad commit 71f5abc

File tree

13 files changed: +1147 −22 lines changed

ppq/IR/morph.py

Lines changed: 32 additions & 0 deletions
@@ -1042,6 +1042,38 @@ def fuse_selfattention(self):
             if v is not None: non_empty_attr[k] = v
             op._attributes = non_empty_attr

+    def fuse_matmul_add(self, verbose: bool = True):
+        """
+        Fuse MatMul + bias Add into PPQBiasFusedMatMul.
+
+        PPQBiasFusedMatMul is a temporary operation which will be split back when exporting.
+        """
+        graph, fused = self.graph, False
+        for current_op in [_ for _ in graph.operations.values()]:
+            if current_op.type != 'MatMul': continue
+
+            # check that the only downstream op is an Add
+            next_ops = graph.get_downstream_operations(current_op)
+            if len(next_ops) != 1: continue
+            if next_ops[0].type != 'Add': continue
+
+            # check whether it is a constant (bias) add
+            fusing_op = next_ops[0]
+            if fusing_op.num_of_parameter == 1:
+
+                # do graph fusion
+                bias = fusing_op.parameters[0].value
+                graph.remove_operation(fusing_op, keep_coherence=True)
+                graph.create_variable(value=bias, is_parameter=True, dest_ops=[current_op])
+                current_op.type = 'PPQBiasFusedMatMul'
+                fused = True
+
+                if verbose:
+                    print(f'Fusing graph op: {current_op.name} + {fusing_op.name}')
+
+        if not fused:
+            ppq_warning("No suitable matmul + add was found, check your graph again.")


 class GraphDecomposer(GraphCommandProcessor):
     """Since PPQ 0.6.4, GraphDecomposer is introduced to split some complex

ppq/parser/onnx_exporter.py

Lines changed: 1 addition & 12 deletions
@@ -55,17 +55,6 @@ def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation:
         graph.create_variable(value=bias.value, is_parameter=True, dest_ops=[bias_op])
         graph.remove_variable(op.inputs[-1])

-class PPQBiasFusedMatMulExporter(OperationExporter):
-    def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation:
-        if op.num_of_input == 3: bias = op.inputs[-1]
-        assert bias.is_parameter and bias.value is not None, 'MatMul Format Error'
-
-        bias_op = graph.create_operation(op_type='Add')
-        op.type = 'MatMul'
-        graph.insert_op_after(bias_op, op)
-        graph.create_variable(value=bias.value, is_parameter=True, dest_ops=[bias_op])
-        graph.remove_variable(op.inputs[-1])
-
 OP_CONVERTERS = {
     'ConstantOfShape': ConstantOfShapeExporter,
     'MMCVRoiAlign': MMCVExporter,

@@ -83,7 +72,7 @@ def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation:
     'QLinearMul': OOSExporter,
     'QLinearReduceMean': OOSExporter,
     'QLinearSigmoid': OOSExporter,
-    'PPQBiasFusedMatMul': PPQBiasFusedMatMulExporter
+    # 'PPQBiasFusedMatMul': PPQBiasFusedMatMulExporter
 }

 def convert_value(value: Union[int, float, np.ndarray, torch.Tensor]) -> str:

ppq/quantization/algorithm/equalization.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def key_value_from_upstream(
    # ----------------------------------
    # step - 3, extract activation from op:
    # ----------------------------------
-    if including_act and op.inputs[0].value is not None:
+    if including_act and op.outputs[0].value is not None:
        a = op.outputs[0].value * act_multiplier
        buffer.append(a)

ppq/quantization/algorithm/training.py

Lines changed: 3 additions & 3 deletions
@@ -332,7 +332,7 @@ def __init__(
         if self.is_parameter and is_parameter_trainable:
             self.param_backup = self.var.value.clone()

-        # There is 4 checks for training scale:
+        # There are 4 checks for scale training:
         # 1. scale is valid
         # 2. state is active
         # 3. do not have POWER_OF_2 policy but Must have Linear policy

@@ -348,7 +348,7 @@ def __init__(
             self.is_scale_trainable = True
             self.scale_backup = self.config.scale.detach().clone()

-        # There is 4 checks for training offset:
+        # There are 4 checks for offset training:
         # 1. offset is valid
         # 2. state is active
         # 3. do not have SYMMETRICAL policy

@@ -419,4 +419,4 @@ def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor:
         quantized = (quantized - offset.detach()) * scale
         quantized = quantized
         return quantized
-
+

ppq/quantization/observer/range.py

Lines changed: 1 addition & 1 deletion
@@ -240,7 +240,7 @@ def hist_to_scale_offset(

        losses, quant_bins = [], 2 ** (config.num_of_bits - 1)

-        # following code is curcial, do not move
+        # the following code is crucial, do not remove
        histogram[: int(hist_bins * .002)] = 0
        histogram[int(hist_bins * .002)] = 1
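As a hedged illustration of the guarded lines above (the bin count of 2048 is an assumption for this sketch, not a value from the diff): the observer zeroes the lowest 0.2% of histogram bins and forces the boundary bin to 1, suppressing near-zero mass before the scale/offset loss search.

    import torch

    hist_bins = 2048                  # assumed bin count, for illustration only
    histogram = torch.ones(hist_bins)
    cut = int(hist_bins * .002)       # 2048 * 0.002 -> 4
    histogram[:cut] = 0               # bins 0..3 are cleared
    histogram[cut] = 1                # boundary bin forced to 1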

ppq/quantization/optim/training.py

Lines changed: 5 additions & 2 deletions
@@ -709,8 +709,9 @@ class LearnedStepSizePass(TrainingBasedPass):
     def __init__(
         self, name: str = 'PPQ LSQ Optimization', interested_layers: List[str] = [],
         steps: int = 500, gamma: float = 0.0, is_scale_trainable: bool = True,
-        lr: float = 5e-5, block_size: int = None, expire_device: str = 'cpu',
+        lr: float = 5e-5, block_size: int = 5, expire_device: str = 'cpu',
         collecting_device: str = 'cuda', loss_fn: Callable = torch_mean_square_error,
+        optimizer: Any = None
     ) -> None:
         super().__init__(name=name)
         self.interested_layers = interested_layers

@@ -722,6 +723,7 @@ def __init__(
         self.gamma = gamma
         self.steps = steps
         self.lr = lr
+        self.optimizer = optimizer

     def finetune(
         self, steps: int, learning_rate: float, block: TrainableBlock, executor: TorchExecutor,

@@ -764,8 +766,9 @@ def finetune(
             return 0, 0

         # initialize optimizer.
-        if optimizer is None:
+        if self.optimizer is None:
             optimizer = torch.optim.Adam(tensors, lr=learning_rate)
+        else: optimizer = self.optimizer(tensors, lr=learning_rate)

         dataset_length = len(qt_inputs)
         if dataset_length == 0: raise ValueError('Dataset is empty.')
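Judging from the branch above, the new optimizer argument takes an optimizer class rather than an instance, since finetune() invokes self.optimizer(tensors, lr=learning_rate). A minimal sketch of the new interface, assuming any torch.optim class with a (params, lr=...) constructor works:

    import torch
    from ppq.quantization.optim import LearnedStepSizePass

    # pass the optimizer class itself; the pass instantiates it
    # internally as self.optimizer(tensors, lr=learning_rate)
    lsq = LearnedStepSizePass(
        steps=500, collecting_device='cuda', block_size=5,
        optimizer=torch.optim.SGD)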
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
# Test Quantization System Performance on Image Classification Models with ILSVRC2012 Dataset

# Should contain model files (.onnx)
MODEL_DIR = 'QuantZoo/Model/Imagenet'

# Should contain Calib & Test image folders
CALIB_DIR = 'QuantZoo/Data/Imagenet/Calib'
TEST_DIR = 'QuantZoo/Data/Imagenet/Test'

# calibration & test batchsize
BATCHSIZE = 32

# Quantizer Configuration
SYMMETRICAL = True
PER_CHANNEL = True
POWER_OF_2 = False
BIT_WIDTH = 8

# write report to here
REPORT_DIR = 'QuantZoo/Reports'

CONFIGS = [
    {
        'Model': 'efficientnet_v1_b0',
        'Output': ['/features/features.8/features.8.2/Mul_output_0']
    },
    {
        'Model': 'efficientnet_v1_b1',
        'Output': ['/features/features.8/features.8.2/Mul_output_0']
    },
    {
        'Model': 'efficientnet_v2_s',
        'Output': ['/features/features.7/features.7.2/Mul_output_0']
    },
    {
        'Model': 'mnasnet0_5',
        'Output': ['/layers/layers.16/Relu_output_0']
    },
    {
        'Model': 'mnasnet1_0',
        'Output': ['/layers/layers.16/Relu_output_0']
    },
    {
        'Model': 'mobilenet_v2',
        'Output': ['/features/features.18/features.18.2/Clip_output_0']
    },
    {
        'Model': 'resnet18',
        'Output': ['/layer4/layer4.1/relu_1/Relu_output_0']
    },
    {
        'Model': 'resnet50',
        'Output': ['/layer4/layer4.2/relu_2/Relu_output_0']
    },
    {
        'Model': 'mobilenet_v3_large',
        'Output': ['/classifier/classifier.1/Mul_output_0']
    },
    {
        'Model': 'mobilenet_v3_small',
        'Output': ['/classifier/classifier.1/Mul_output_0']
    },
    {
        'Model': 'v100_gpu64@[email protected]_finetune@25',
        'Output': ['471']
    },
    {
        'Model': 'v100_gpu64@[email protected]_finetune@25',
        'Output': ['471']
    },
    {
        # vit_b_16 requires BATCHSIZE = 1!
        'Model': 'vit_b_16',
        'Output': ['onnx::Gather_1703']
    }
]

import os

import torch

import ppq.lib as PFL
from ppq.api import ENABLE_CUDA_KERNEL, load_onnx_graph
from ppq.core import TargetPlatform
from ppq.executor import TorchExecutor
from ppq.quantization.optim import (LayerwiseEqualizationPass,
                                    LearnedStepSizePass, ParameterQuantizePass,
                                    RuntimeCalibrationPass)
from QuantZoo.Data.Imagenet.Eval import (evaluate_ppq_module_with_imagenet,
                                         load_imagenet_from_directory)
from QuantZoo.Quantizers import MyFP8Quantizer, MyInt8Quantizer
from QuantZoo.Util import error_analyze


calib_loader = load_imagenet_from_directory(
    directory=CALIB_DIR, batchsize=BATCHSIZE,
    shuffle=False, require_label=False,
    num_of_workers=8)

test_loader = load_imagenet_from_directory(
    directory=TEST_DIR, batchsize=BATCHSIZE,
    shuffle=False, require_label=True,
    num_of_workers=8)


with ENABLE_CUDA_KERNEL():
    for config in CONFIGS:
        model = config['Model']
        monitoring_vars = config['Output']

        print(f"Ready to run quant benchmark on {model}")
        graph = load_onnx_graph(onnx_import_file=os.path.join(MODEL_DIR, model + '.onnx'))

        if model == 'vit_b_16':
            if BATCHSIZE == 32:
                raise Exception('To evaluate vit_b_16, set BATCHSIZE to 1 and change the calibration method to minmax.')
            from ppq.IR import GraphMerger
            processor = GraphMerger(graph)
            processor.fuse_matmul_add()
            processor.fuse_layernorm()
            processor.fuse_gelu()

        quantizer = MyInt8Quantizer(
            graph=graph, sym=SYMMETRICAL, power_of_2=POWER_OF_2,
            num_of_bits=BIT_WIDTH, per_channel=PER_CHANNEL)
        # quantizer = MyFP8Quantizer(graph=graph)

        # convert ops to quantable ops
        for name, op in graph.operations.items():
            if op.type in {'Conv', 'ConvTranspose', 'MatMul', 'Gemm',
                           'PPQBiasFusedMatMul', 'LayerNormalization'}:
                quantizer.quantize_operation(name, platform=TargetPlatform.INT8)

        # build quant pipeline.
        pipeline = PFL.Pipeline([
            # LayerwiseEqualizationPass(iteration=10),
            ParameterQuantizePass(),
            RuntimeCalibrationPass(),
            # LearnedStepSizePass(steps=500, collecting_device='cuda', block_size=5)
        ])

        # call pipeline.
        executor = TorchExecutor(graph=graph)
        executor.tracing_operation_meta(torch.zeros(size=[BATCHSIZE, 3, 224, 224]).cuda())

        pipeline.optimize(
            graph=graph, dataloader=calib_loader, verbose=True,
            calib_steps=32, collate_fn=lambda x: x.to('cuda'), executor=executor)

        # evaluation
        acc = evaluate_ppq_module_with_imagenet(
            model=graph, imagenet_validation_loader=test_loader,
            batchsize=BATCHSIZE, device='cuda', verbose=False)
        print(f'Model Classification Accuracy = {acc: .4f}%')

        # error analysis
        performance = error_analyze(
            graph=graph,
            outputs=monitoring_vars,
            dataloader=test_loader,
            collate_fn=lambda x: x[0].to('cuda'),
            verbose=True
        )
