@@ -31,16 +31,25 @@ def __call__(self, module, input_batch, output_batch):
31
31
32
32
33
33
class PerChannelLoadHook :
34
- def __init__ (self , module ):
34
+ def __init__ (self , module , hook_param = [ "scale" , "zero_point" ] ):
35
35
self .hook = module ._register_load_state_dict_pre_hook (partial (self .hook_fn , module = module ))
36
+ self .hook_param = hook_param
36
37
37
38
def hook_fn (self , state_dict , prefix , local_metadata , strict , missing_keys , unexpected_keys , error_msgs ,
38
39
module ):
39
40
if module .ch_axis == - 1 :
40
41
# no per-channel parameters
41
42
return
42
43
for module_key , param in module ._parameters .items ():
43
- if module_key not in ["scale" , "zero_point" ]:
44
+ if module_key not in self .hook_param :
45
+ continue
46
+ candidate = prefix + module_key
47
+ if candidate in state_dict :
48
+ input_param = state_dict [candidate ]
49
+ if param .shape != input_param .shape :
50
+ param .data = torch .ones_like (input_param , dtype = param .dtype , device = param .device )
51
+ for module_key , param in module ._buffers .items ():
52
+ if module_key not in self .hook_param :
44
53
continue
45
54
candidate = prefix + module_key
46
55
if candidate in state_dict :
0 commit comments