Commit 55c304b

PannenetsF and fanyunqian authored
[Fix] AdaRound load state dict missing alpha. (#140)
* [Fix] AdaRound load state dict missing alpha.
* [Feat] per channel hooks for buffer

Co-authored-by: fanyunqian <[email protected]>

1 parent d3292fa · commit 55c304b

File tree

2 files changed: +12 -3 lines changed


mqbench/fake_quantize/adaround_quantizer.py

+1 -1

@@ -48,7 +48,7 @@ def __init__(self, observer, **observer_kwargs):
         self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
         self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
         self.adaround = False
-        self.load_state_dict_hook = PerChannelLoadHook(self)
+        self.load_state_dict_hook = PerChannelLoadHook(self, hook_param=['scale', 'zero_point', 'alpha'])
 
     def init(self, weight_tensor: torch.Tensor, round_mode='learned_hard_sigmoid', ):
         self.adaround = True
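The hunk above only extends the hook_param list; the sketch below is a hedged, self-contained illustration of the kind of failure it guards against (ToyAdaRound, its placeholder alpha, and the checkpoint are made up for the example, not MQBench code): a per-channel parameter such as AdaRound's alpha cannot be loaded from a checkpoint whose shape differs from the freshly built module unless a pre-hook resizes it first.

# Toy illustration (assumed names, not MQBench code): without a resizing
# pre-hook, loading a per-channel 'alpha' into a freshly built module fails.
import torch
import torch.nn as nn

class ToyAdaRound(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder shape; the real per-channel shape is only known later.
        self.alpha = nn.Parameter(torch.zeros(1))

fresh = ToyAdaRound()
ckpt = {"alpha": torch.zeros(8)}  # e.g. saved from an 8-channel weight
try:
    fresh.load_state_dict(ckpt)
except RuntimeError as err:
    print(err)  # size mismatch for alpha: checkpoint shape [8] vs model shape [1]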

mqbench/utils/hook.py

+11 -2

@@ -31,16 +31,25 @@ def __call__(self, module, input_batch, output_batch):
 
 
 class PerChannelLoadHook:
-    def __init__(self, module):
+    def __init__(self, module, hook_param=["scale", "zero_point"]):
         self.hook = module._register_load_state_dict_pre_hook(partial(self.hook_fn, module=module))
+        self.hook_param = hook_param
 
     def hook_fn(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs,
                 module):
         if module.ch_axis == -1:
             # no per-channel parameters
             return
         for module_key, param in module._parameters.items():
-            if module_key not in ["scale", "zero_point"]:
+            if module_key not in self.hook_param:
+                continue
+            candidate = prefix + module_key
+            if candidate in state_dict:
+                input_param = state_dict[candidate]
+                if param.shape != input_param.shape:
+                    param.data = torch.ones_like(input_param, dtype=param.dtype, device=param.device)
+        for module_key, param in module._buffers.items():
+            if module_key not in self.hook_param:
                 continue
             candidate = prefix + module_key
             if candidate in state_dict:
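For context, here is a minimal, self-contained sketch of the mechanism this hunk implements: a load_state_dict pre-hook that resizes every parameter and buffer named in hook_param to the checkpoint's shape before the values are copied in. ToyQuantizer and _resize_fn are illustrative stand-ins, not the MQBench classes.

# Hedged sketch of the pre-hook resize mechanism (toy module, assumed names).
from functools import partial

import torch
import torch.nn as nn


class ToyQuantizer(nn.Module):
    def __init__(self, hook_param=("scale", "zero_point", "alpha")):
        super().__init__()
        self.ch_axis = 0  # per-channel quantization along dim 0
        # Placeholder shapes; the real per-channel size is only known at load time.
        self.register_buffer("scale", torch.tensor([1.0]))
        self.register_buffer("zero_point", torch.tensor([0], dtype=torch.int))
        self.alpha = nn.Parameter(torch.zeros(1))
        self.hook_param = hook_param
        self._register_load_state_dict_pre_hook(partial(self._resize_fn, module=self))

    @staticmethod
    def _resize_fn(state_dict, prefix, local_metadata, strict,
                   missing_keys, unexpected_keys, error_msgs, module):
        if module.ch_axis == -1:
            return  # per-tensor mode: nothing is per-channel shaped
        for container in (module._parameters, module._buffers):
            for key, tensor in container.items():
                if key not in module.hook_param or tensor is None:
                    continue
                candidate = prefix + key
                if candidate in state_dict and tensor.shape != state_dict[candidate].shape:
                    # Reshape in place so the subsequent copy succeeds.
                    tensor.data = torch.ones_like(state_dict[candidate],
                                                  dtype=tensor.dtype, device=tensor.device)


# Usage: an 8-channel checkpoint loads into a freshly built module with
# 1-element placeholders, because the hook resizes them first.
ckpt = {"scale": torch.rand(8),
        "zero_point": torch.zeros(8, dtype=torch.int),
        "alpha": torch.zeros(8)}
q = ToyQuantizer()
q.load_state_dict(ckpt)
print(q.alpha.shape)  # torch.Size([8])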
