This repository was archived by the owner on Sep 18, 2024. It is now read-only.
Question about load_calibration_config function of the QAT quantizer #4396
Description
```python
def load_calibration_config(self, calibration_config):
    modules_to_compress = self.get_modules_to_compress()
    for layer, _ in modules_to_compress:
        name, module = layer.name, layer.module
        if name not in calibration_config:
            if module.layer_quant_setting.weight or module.layer_quant_setting.input or module.layer_quant_setting.output:
                logger.warning(f"Can not find module {name}'s parameter in input config.")
            continue
        if module.layer_quant_setting.weight:
            assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
                f"weight bits of module {name} fail to match"
        if module.layer_quant_setting.input:
            assert calibration_config[name]['input_bits'] == module.layer_quant_setting.input.bits, \
                f"input bits of module {name} fail to match"
            module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
            module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
        if module.layer_quant_setting.output:
            assert calibration_config[name]['output_bits'] == module.layer_quant_setting.output.bits, \
                f"output bits of module {name} fail to match"
            module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
            module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])
```
The code above is from lines 418 to 438 of qat_quantizer.py. It loads tracked_min and tracked_max for both the input and the output, but for the weight it only checks whether weight_bits in calibration_config matches the bit width from the config list, and then does nothing else.
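To make the asymmetry concrete, here is a torch-free sketch that mirrors the control flow of the quoted method. `ToyModule`, `QuantSetting`, and the `weight_scale`/`weight_zero_point` config keys are hypothetical stand-ins (plain floats replace the registered buffers, and the missing-module warning is reduced to a `continue`); this is an illustration, not NNI code.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class QuantSetting:
    """Hypothetical stand-in for one entry of layer_quant_setting (bit width only)."""
    bits: int


@dataclass
class ToyModule:
    """Hypothetical stand-in for a wrapped module; plain floats replace torch buffers."""
    weight: Optional[QuantSetting] = None
    input: Optional[QuantSetting] = None
    output: Optional[QuantSetting] = None
    tracked_min_input: float = 0.0
    tracked_max_input: float = 0.0
    tracked_min_output: float = 0.0
    tracked_max_output: float = 0.0
    weight_scale: float = 1.0
    weight_zero_point: float = 0.0


def load_calibration_config(modules: Dict[str, ToyModule], calibration_config: dict) -> None:
    """Mirrors the control flow of the quoted method, without torch."""
    for name, module in modules.items():
        if name not in calibration_config:
            continue  # the real method also logs a warning here
        cfg = calibration_config[name]
        if module.weight:
            # Weight: only the bit width is checked; nothing is copied back.
            assert cfg["weight_bits"] == module.weight.bits
        if module.input:
            assert cfg["input_bits"] == module.input.bits
            module.tracked_min_input = cfg["tracked_min_input"]
            module.tracked_max_input = cfg["tracked_max_input"]
        if module.output:
            assert cfg["output_bits"] == module.output.bits
            module.tracked_min_output = cfg["tracked_min_output"]
            module.tracked_max_output = cfg["tracked_max_output"]


modules = {"conv1": ToyModule(weight=QuantSetting(8), input=QuantSetting(8))}
config = {
    "conv1": {
        "weight_bits": 8,
        "weight_scale": 0.02,        # hypothetical key: ignored by the loader
        "weight_zero_point": 128.0,  # hypothetical key: ignored by the loader
        "input_bits": 8,
        "tracked_min_input": -1.5,
        "tracked_max_input": 2.5,
    }
}
load_calibration_config(modules, config)
print(modules["conv1"].tracked_min_input, modules["conv1"].tracked_max_input)  # -1.5 2.5
print(modules["conv1"].weight_scale)  # 1.0: the default, since the config value is never read
```

The input range is restored from the config, while the weight entries (other than the bit-width check) leave the module untouched.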
As I understand it, would it be reasonable to also load weight_scale and weight_zero_point from calibration_config, as in the following code?
```python
if module.layer_quant_setting.weight:
    assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
        f"weight bits of module {name} fail to match"
    module.weight_scale = calibration_config[name]['weight_scale'].data
    module.weight_zero_point = calibration_config[name]['weight_zero_point'].data
```
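One point worth weighing against this proposal: in asymmetric (affine) QAT schemes, the weight scale and zero point are usually derived from the weight tensor's own min/max, so once the trained weights are loaded they can in principle be recomputed rather than stored. The sketch below shows that standard derivation in plain Python. `affine_quant_params` is a hypothetical helper, and whether QAT_Quantizer recomputes the weight parameters this way on the next forward pass is an assumption to check against qat_quantizer.py.

```python
from typing import Tuple


def affine_quant_params(bits: int, rmin: float, rmax: float) -> Tuple[float, float]:
    """Asymmetric (affine) quantization parameters for an unsigned `bits`-bit range.

    Sketch only: extends [rmin, rmax] to contain 0, then nudges the zero point
    to an integer inside [qmin, qmax].
    """
    rmin = min(rmin, 0.0)
    rmax = max(rmax, 0.0)
    qmin, qmax = 0.0, float((1 << bits) - 1)
    if rmax == rmin:  # all-zero tensor: avoid dividing by zero
        return 1.0, 0.0
    # Scale maps the real range onto the integer range.
    scale = (rmax - rmin) / (qmax - qmin)
    # Real value 0.0 should map exactly onto an integer code.
    initial_zero_point = qmin - rmin / scale
    zero_point = float(min(max(round(initial_zero_point), qmin), qmax))
    return scale, zero_point


weights = [-0.5, 0.1, 0.75, 1.0]  # toy weight values
scale, zero_point = affine_quant_params(8, min(weights), max(weights))
print(zero_point)  # 85.0
```

If the quantizer does recompute these values from the weight, loading stored ones would be redundant during training; storing them could still help a consumer that reads only the calibration file.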