This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Question about the load_calibration_config function of the QAT quantizer #4396

Open
@Lycan1003

Description

    def load_calibration_config(self, calibration_config):
        modules_to_compress = self.get_modules_to_compress()
        for layer, _ in modules_to_compress:
            name, module = layer.name, layer.module
            if name not in calibration_config:
                if module.layer_quant_setting.weight or module.layer_quant_setting.input or module.layer_quant_setting.output:
                    logger.warning(f"Can not find module {name}'s parameter in input config.")
                continue
            if module.layer_quant_setting.weight:
                assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
                    f"weight bits of module {name} fail to match"
            if module.layer_quant_setting.input:
                assert calibration_config[name]['input_bits'] == module.layer_quant_setting.input.bits, \
                    f"input bits of module {name} fail to match"
                module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
                module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
            if module.layer_quant_setting.output:
                assert calibration_config[name]['output_bits'] == module.layer_quant_setting.output.bits, \
                    f"output bits of module {name} fail to match"
                module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
                module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])

The code above is from lines 418 to 438 of qat_quantizer.py. As we can see, it loads tracked_min and tracked_max for both the input and the output, but for the weight it only checks whether weight_bits matches the config and then does nothing else.
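
For context, in the common asymmetric affine quantization scheme a scale and a zero point are derived from a tensor's range (min/max) and the bit width, which is the same kind of range that tracked_min/tracked_max record for inputs and outputs. The following is a minimal pure-Python sketch assuming that scheme; affine_qparams, quantize and dequantize are hypothetical helper names, and NNI's own internal helper may differ in rounding or clamping details.

```python
def affine_qparams(t_min: float, t_max: float, bits: int = 8):
    """Return (scale, zero_point) for asymmetric uniform quantization.

    Assumes q = round(x / scale) + zero_point with q clamped to
    [0, 2**bits - 1]. Illustrative only, not NNI's implementation.
    """
    qmin, qmax = 0, 2 ** bits - 1
    # Widen the range to include 0.0 so that zero is exactly representable.
    t_min, t_max = min(t_min, 0.0), max(t_max, 0.0)
    if t_max == t_min:
        # Degenerate range (all zeros): any positive scale works.
        return 1.0, 0
    scale = (t_max - t_min) / (qmax - qmin)
    zero_point = int(round(qmin - t_min / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x: float, scale: float, zero_point: int, bits: int = 8) -> int:
    # Map a real value to an integer code and clamp it to the valid range.
    q = int(round(x / scale)) + zero_point
    return max(0, min(2 ** bits - 1, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    # Map an integer code back to an approximate real value.
    return (q - zero_point) * scale


# Example: 8-bit parameters for a tiny "weight" tensor.
weights = [-0.5, 0.1, 0.75]
scale, zero_point = affine_qparams(min(weights), max(weights), bits=8)
codes = [quantize(w, scale, zero_point) for w in weights]
restored = [dequantize(q, scale, zero_point) for q in codes]
```

Here scale and zero_point are fully determined by the weight's min/max and the bit width, so they can be reproduced from the weight tensor itself, whereas the input/output ranges depend on the data seen during training.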

As I understand it, it would be reasonable to also load weight_scale and weight_zero_point from calibration_config, as in the following code. Is that right?

            if module.layer_quant_setting.weight:
                assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
                    f"weight bits of module {name} fail to match"
                module.weight_scale = calibration_config[name]['weight_scale'].data
                module.weight_zero_point = calibration_config[name]['weight_zero_point'].data
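
A minimal sketch of the proposed weight branch, written without torch so it runs standalone. WeightQuantState and load_weight_calibration are hypothetical names, and whether export_model actually writes weight_scale and weight_zero_point into the calibration config is an assumption to verify against qat_quantizer.py.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class WeightQuantState:
    # Hypothetical stand-in for the weight quantization state kept on the
    # wrapped module (weight bits, weight_scale, weight_zero_point).
    bits: int
    scale: Optional[float] = None
    zero_point: Optional[int] = None


def load_weight_calibration(state: WeightQuantState, entry: Dict[str, Any]) -> WeightQuantState:
    """Check weight bits, then restore scale/zero point if the entry has them."""
    if entry["weight_bits"] != state.bits:
        raise ValueError(f"weight bits fail to match: {entry['weight_bits']} != {state.bits}")
    # The proposed keys are assumed to be written by the exporter; skip
    # restoring (instead of raising KeyError) when an older config lacks them.
    if "weight_scale" in entry and "weight_zero_point" in entry:
        state.scale = entry["weight_scale"]
        state.zero_point = entry["weight_zero_point"]
    return state


# Example entry shaped like calibration_config[name], plus the proposed keys.
entry = {"weight_bits": 8, "weight_scale": 1.25 / 255, "weight_zero_point": 102}
state = load_weight_calibration(WeightQuantState(bits=8), entry)
```

Raising ValueError rather than using assert keeps the bit-width check active under `python -O`, and treating the two new keys as optional keeps calibration configs exported before such a change loadable.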
