diff --git a/requirements_frameworks.txt b/requirements_frameworks.txt
index bf6a2c314..0334d3635 100644
--- a/requirements_frameworks.txt
+++ b/requirements_frameworks.txt
@@ -17,6 +17,7 @@ dglgo==0.0.2
tflite
+paddleslim==2.6.0
paddlepaddle==2.6.0
--extra-index-url https://mirror.baidu.com/pypi/simple
diff --git a/src/benchmark/frameworks/paddlepaddle/paddlepaddle_process.py b/src/benchmark/frameworks/paddlepaddle/paddlepaddle_process.py
index 9b3119716..387a75923 100644
--- a/src/benchmark/frameworks/paddlepaddle/paddlepaddle_process.py
+++ b/src/benchmark/frameworks/paddlepaddle/paddlepaddle_process.py
@@ -24,10 +24,11 @@ def _fill_command_line(self):
model = self._test.model.model
params = self._test.model.weight
dataset = self._test.dataset.path if self._test.dataset else None
+ iteration = self._test.indep_parameters.iteration
batch = self._test.indep_parameters.batch_size
device = self._test.indep_parameters.device
- common_params = (f'-m {model} -p {params} -i {dataset} -b {batch} -d {device} '
+ common_params = (f'-m {model} -p {params} -ni {iteration} -i {dataset} -b {batch} -d {device} '
f'--report_path {self.report_path}')
common_params = self._add_optional_argument_to_cmd_line(common_params, '-i', dataset)
diff --git a/src/configs/paddle_quantization_config_template.xml b/src/configs/paddle_quantization_config_template.xml
new file mode 100644
index 000000000..00737078a
--- /dev/null
+++ b/src/configs/paddle_quantization_config_template.xml
@@ -0,0 +1,28 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/inference/inference_paddlepaddle.py b/src/inference/inference_paddlepaddle.py
index 80debca61..ed4ecc527 100644
--- a/src/inference/inference_paddlepaddle.py
+++ b/src/inference/inference_paddlepaddle.py
@@ -126,8 +126,8 @@ def inference_paddlepaddle(predictor, number_iter, get_slice, test_duration):
input_info = predictor.get_input_names()
outputs = predictor.get_output_names()
if number_iter > 1:
- time_infer, _ = loop_inference(number_iter, test_duration)(inference_iteration)(get_slice,
- input_info, predictor)
+ loop_results = loop_inference(number_iter, test_duration)(inference_iteration)(get_slice, input_info, predictor)
+ time_infer = loop_results['time_infer']
else:
exec_time = inference_iteration(get_slice, input_info, predictor)
result = {}
@@ -142,6 +142,7 @@ def inference_paddlepaddle(predictor, number_iter, get_slice, test_duration):
def inference_iteration(get_slice, input_info, predictor):
for name, data in get_slice().items():
input_tensor = predictor.get_input_handle(name)
+ input_tensor.reshape(data.shape)
input_tensor.copy_from_cpu(data)
_, exec_time = infer_slice(predictor)
return exec_time
diff --git a/src/quantization/paddlepaddle/README.md b/src/quantization/paddlepaddle/README.md
new file mode 100644
index 000000000..3bf31778e
--- /dev/null
+++ b/src/quantization/paddlepaddle/README.md
@@ -0,0 +1,48 @@
+# PaddlePaddle quantization script
+
+Script name:
+
+```bash
+quantization_paddlepaddle.py
+```
+
+Required arguments:
+
+- `-c / --config` is a path to the file containing information
+ about quantization process in the xml-format. Template of the configuration file
+ located [here][config_path].
+
+Description of parameters:
+
+`Model` contains information about model to be quantized:
+- `Name` is a name of the model.
+- `PathPrefix` is a path to the model files without the extensions (.pdmodel, .pdiparams).
+- `ModelDir` is a directory with the model.
+- `ModelFileName` is a file name of the model description.
+- `ParamsFileName` is a file name of the model parameters.
+
+`Dataset` contains information about dataset for the model calibration:
+- `Name` is a dataset name.
+- `Path` is a path to the directory that contains input data.
+- `Mean` is a mean value for preprocessing data.
+- `Std` is a scale value for preprocessing data.
+- `ChannelSwap` is a flag to transpose for image channels. For RGB - 2, 1, 0. For BGR - 0, 1, 2.
+- `ResizeResolution` is an image size for preprocessing data. Example: 224, 224.
+- `BatchSize` is a batch size.
+- `BatchNum` is the total number of batches
+
+`QuantizationParameters` contains information about the model input layer:
+- `InputShape` is a shape of the model's input layer.
+- `InputName` is a name of the model's input layer.
+- `SaveDir` is a directory for the quantized model to be saved.
+- `Algorithm` specifies method to calculate the quantization scale factor. Available: 'KL', 'hist', 'mse', 'avg',\
+- 'abs_max'. If algo='KL', use [KL-divergent method][KL] to get the scale factor. If algo='hist', use the hist_percent \
+- of histogram to get the scale factor. If algo='mse', search for the best scale factor which makes the \
+- [mse loss minimal][MSE]. Use one batch of data for mse is enough. If algo='avg', use the average of abs_max values \
+- to get the scale factor. If algo='abs_max', use abs_max method to get the scale factor. Default: 'hist'.
+
+
+
+[config_path]: ../../configs/paddle_quantization_config_template.xml
+[KL] https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
+[MSE] https://en.wikipedia.org/wiki/Minimum_mean_square_error
diff --git a/src/quantization/paddlepaddle/__init__.py b/src/quantization/paddlepaddle/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/quantization/paddlepaddle/parameters.py b/src/quantization/paddlepaddle/parameters.py
new file mode 100644
index 000000000..ed6848c86
--- /dev/null
+++ b/src/quantization/paddlepaddle/parameters.py
@@ -0,0 +1,133 @@
+import sys
+from pathlib import Path
+import random
+import numpy as np
+import paddle
+from paddle.io import Dataset
+import ast
+from paddle.io import DataLoader
+from paddleslim.quant import quant_post_static
+import cv2
+sys.path.append(str(Path(__file__).resolve().parents[1]))
+from utils import ArgumentsParser # noqa: E402
+
+
+class PaddleDatasetReader(Dataset):
+ def __init__(self, args, log):
+ super(PaddleDatasetReader, self).__init__()
+ self.log = log
+ self.log.info('Parsing dataset arguments.')
+ self.data_dir = args['Path']
+
+ self.resize_size = ast.literal_eval(args['ResizeResolution'])
+ self.mean = np.array((np.asarray(ast.literal_eval(args['Mean']), dtype=np.float32)
+ if args['Mean'] is not None else [0., 0., 0.])).reshape((3, 1, 1))
+ self.std = np.array((np.asarray(ast.literal_eval(args['Std']), dtype=np.float32)
+ if args['Std'] is not None else [1., 1., 1.])).reshape((3, 1, 1))
+ self.channel_swap = ast.literal_eval(args['ChannelSwap']) if args['ChannelSwap'] is not None else [2, 0, 1]
+ self.batch_size = int(args['BatchSize'])
+ self.batch_num = int(args['BatchNum'])
+ self.dataset = list(Path(self.data_dir).glob('*'))
+ random.shuffle(self.dataset)
+ self.dataset_iter = iter(self.dataset)
+
+ def __getitem__(self, index):
+ image_path = str(self.dataset[index].absolute())
+ data = self.process_image(image_path)
+ return data
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def process_image(self, image_path):
+
+ img = cv2.imread(image_path)
+ if img.size == 0:
+ self.log.info('failed to read:', image_path)
+ return None
+ img = cv2.resize(img, self.resize_size)
+
+ img = img.astype('float32').transpose(tuple(self.channel_swap)) / 255
+ img -= self.mean
+ img /= self.std
+
+ return img
+
+
+class PaddleQuantizationProcess:
+ def __init__(self, log, model_reader, dataset, quant_params):
+ self.log = log
+ self.model_reader = model_reader
+ self.dataset = dataset
+ self.quant_params = quant_params
+
+ def transform_fn(self):
+ for data in self.dataset:
+ yield [data.astype(np.float32)]
+
+ def quantization_paddle(self):
+ place = paddle.CPUPlace()
+ exe = paddle.static.Executor(place)
+ data_loader = DataLoader(
+ self.dataset,
+ places=place,
+ feed_list=[self.quant_params.image],
+ drop_last=False,
+ return_list=False,
+ batch_size=self.dataset.batch_size,
+ shuffle=False)
+
+ quant_post_static(
+ executor=exe,
+ model_dir=self.model_reader.model_dir,
+ quantize_model_path=self.quant_params.save_dir,
+ data_loader=data_loader,
+ model_filename=self.model_reader.model_filename,
+ params_filename=self.model_reader.params_filename,
+ batch_size=self.dataset.batch_size,
+ batch_nums=self.dataset.batch_num,
+ algo=self.quant_params.algo,
+ round_type='round',
+ hist_percent=0.9999,
+ is_full_quantize=True,
+ bias_correction=False,
+ onnx_format=False)
+
+
+class PaddleModelReader(ArgumentsParser):
+ def __init__(self, log):
+ super().__init__(log)
+
+ def _get_arguments(self):
+ self._log.info('Parsing model arguments.')
+ self.model_name = self.args['Name']
+ self.path_prefix = self.args['PathPrefix']
+ self.model_dir = self.args['ModelDir']
+ self.model_filename = self.args['ModelFileName']
+ self.params_filename = self.args['ParamsFileName']
+
+ def dict_for_iter_log(self):
+ return {
+ 'Name': self.model_name,
+ 'Model path prefix': self.path_prefix,
+ }
+
+
+class PaddleQuantParamReader(ArgumentsParser):
+ def __init__(self, log):
+ super().__init__(log)
+
+ def dict_for_iter_log(self):
+ return {
+ 'InputShape': self.input_shape,
+ 'InputName': self.input_name,
+ 'SaveDir': self.save_dir,
+ 'Algorithm': self.algo,
+ }
+
+ def _get_arguments(self):
+ self.input_shape = ast.literal_eval(self.args['InputShape'])
+ self.image = paddle.static.data(name=self.args['InputName'], shape=[None] + self.input_shape, dtype='float32')
+ self.input_name = self.args['InputName']
+ self.save_dir = self.args['SaveDir']
+ self.algo = self.args['Algorithm']
diff --git a/src/quantization/paddlepaddle/quantization_paddlepaddle.py b/src/quantization/paddlepaddle/quantization_paddlepaddle.py
new file mode 100644
index 000000000..d834c45c3
--- /dev/null
+++ b/src/quantization/paddlepaddle/quantization_paddlepaddle.py
@@ -0,0 +1,55 @@
+import paddle
+import argparse
+import sys
+import traceback
+from pathlib import Path
+from parameters import PaddleModelReader, PaddleDatasetReader, PaddleQuantizationProcess, PaddleQuantParamReader
+sys.path.append(str(Path(__file__).resolve().parents[3]))
+from src.utils.logger_conf import configure_logger # noqa: E402
+from src.quantization.utils import ConfigParser # noqa: E402
+
+
+log = configure_logger()
+
+
+def cli_argument_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-c', '--config',
+ help='Path to the configuration file in the xml-format.',
+ type=str,
+ required=True,
+ dest='config')
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = cli_argument_parser()
+ try:
+ log.info(f'Parsing the configuration file {args.config}')
+ parser = ConfigParser(args.config)
+ paddle.enable_static()
+ config = parser.parse()
+ exit_code = 0
+ quant_params = PaddleQuantParamReader(log)
+ model_reader = PaddleModelReader(log)
+ for model_quant_config in config:
+ try:
+ data_reader = PaddleDatasetReader(model_quant_config[1]['Dataset'], log)
+ model_reader.add_arguments(model_quant_config[0]['Model'])
+ quant_params.add_arguments(model_quant_config[2]['QuantizationParameters'])
+ proc = PaddleQuantizationProcess(log, model_reader, data_reader, quant_params)
+ proc.quantization_paddle()
+
+ except Exception:
+ log.error(traceback.format_exc())
+ exit_code += 1
+ if exit_code:
+ sys.exit(1)
+ except Exception:
+ log.error(traceback.format_exc())
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ sys.exit(main() or 0)
diff --git a/tests/smoke_test/configs/quantization_models/resnet-50_PADDLEPADDLE.xml b/tests/smoke_test/configs/quantization_models/resnet-50_PADDLEPADDLE.xml
new file mode 100644
index 000000000..154ea21b3
--- /dev/null
+++ b/tests/smoke_test/configs/quantization_models/resnet-50_PADDLEPADDLE.xml
@@ -0,0 +1,28 @@
+
+
+
+
+ resnet50-paddle
+ ../models_dir/resnet50_paddle/inference
+ ../models_dir/resnet50_paddle
+ resnet50.pdmodel
+ resnet50.pdiparams
+
+
+ test
+ ../test_images/classification_images
+ [123.675, 116.28, 103.53]
+ [58.395, 57.12, 57.375]
+
+ 1
+ 10
+ [224, 224]
+
+
+ [3, 224, 224]
+ inputs
+ res_dir
+ avg
+
+
+
\ No newline at end of file
diff --git a/tests/smoke_test/quantization_smoke/conftest.py b/tests/smoke_test/quantization_smoke/conftest.py
index 416c4242a..a44feff2a 100644
--- a/tests/smoke_test/quantization_smoke/conftest.py
+++ b/tests/smoke_test/quantization_smoke/conftest.py
@@ -6,6 +6,7 @@
from tests.smoke_test.utils import execute_process
from tests.smoke_test.conftest import (SCRIPT_DIR, OUTPUT_DIR, log,
download_models, convert_models)
+from tests.smoke_test.benchmark_smoke.conftest import download_resnet50_paddle
QUANTIZATION_CONFIG_DIR_PATH = Path(SCRIPT_DIR, 'configs', 'quantization_models')
TVM_CONVERTER = Path.joinpath(SCRIPT_DIR.parents[1], 'src/model_converters/tvm_converter/tvm_converter.py')
@@ -43,6 +44,7 @@ def prepare_dl_models(request, overrided_models):
models_per_mark = DL_MODELS
enabled_models = overrided_models if overrided_models else models_per_mark
+ download_resnet50_paddle()
download_models(models_list=enabled_models)
convert_models(models_list=enabled_models)
convert_models_to_tvm()
diff --git a/tests/smoke_test/quantization_smoke/test_quantization_smoke.py b/tests/smoke_test/quantization_smoke/test_quantization_smoke.py
index ba537c140..64b42db4a 100644
--- a/tests/smoke_test/quantization_smoke/test_quantization_smoke.py
+++ b/tests/smoke_test/quantization_smoke/test_quantization_smoke.py
@@ -6,6 +6,8 @@
QUANTIZATION_TFLITE = Path.joinpath(SCRIPT_DIR.parents[1], 'src/quantization/tflite/quantization_tflite.py')
QUANTIZATION_TVM = Path.joinpath(SCRIPT_DIR.parents[1], 'src/quantization/tvm/quantization_tvm.py')
QUANTIZATION_NNCF = Path.joinpath(SCRIPT_DIR.parents[1], 'src/quantization/nncf/quantization_nncf.py')
+QUANTIZATION_PADDLE = Path.joinpath(SCRIPT_DIR.parents[1],
+ 'src/quantization/paddlepaddle/quantization_paddlepaddle.py')
TVM_CONVERTER = Path.joinpath(SCRIPT_DIR.parents[1], 'src/model_converters/tvm_converter/tvm_converter.py')
@@ -17,6 +19,8 @@ def test_smoke_dl_models(test_configuration):
command_line = (f'python3 {QUANTIZATION_TFLITE} -c {test_configuration.config_path}')
elif test_configuration.framework == 'TVM':
command_line = (f'python3 {QUANTIZATION_TVM} -c {test_configuration.config_path}')
+ elif test_configuration.framework == 'PADDLEPADDLE':
+ command_line = (f'python3 {QUANTIZATION_PADDLE} -c {test_configuration.config_path}')
else:
raise Exception(f'Unsupported framework: {test_configuration.framework}')