@@ -70,24 +70,24 @@ def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype:
7070 param_rms = param_l2_norm / total_numel ** 0.5
7171 self .metrics ["weight_rms" ][layer_name ] = param_rms .item ()
7272
73- def register_attn_extra_info_hook (self , module : nn .Module , layer_name : str ):
73+ def register_attn_extra_info_hook (self , module : nn .Module ):
7474 """Register attention extra info hook as a forward hook"""
7575 def hook (module , input , output ):
7676 extra_info = output [1 ]
7777 if extra_info .get ("softmax_lse" , None ) is not None :
78- if layer_name not in ATTN_MAX_LSE :
78+ if module . name not in ATTN_MAX_LSE :
7979 # original shape: [n_head, seq]
80- ATTN_MAX_LSE [layer_name ] = extra_info ["softmax_lse" ].max ()
80+ ATTN_MAX_LSE [module . name ] = extra_info ["softmax_lse" ].max ()
8181 else :
82- prev_lse_max = ATTN_MAX_LSE [layer_name ]
83- ATTN_MAX_LSE [layer_name ] = max (prev_lse_max , extra_info ["softmax_lse" ].max ())
82+ prev_lse_max = ATTN_MAX_LSE [module . name ]
83+ ATTN_MAX_LSE [module . name ] = max (prev_lse_max , extra_info ["softmax_lse" ].max ())
8484 if extra_info .get ("attn_logits" , None ) is not None :
85- if layer_name not in ATTN_MAX_LOGITS :
85+ if module . name not in ATTN_MAX_LOGITS :
8686 # original shape: [b, n_head, seq, seq]
87- ATTN_MAX_LOGITS [layer_name ] = extra_info ["attn_logits" ].max ()
87+ ATTN_MAX_LOGITS [module . name ] = extra_info ["attn_logits" ].max ()
8888 else :
89- prev_logits_max = ATTN_MAX_LOGITS [layer_name ]
90- ATTN_MAX_LOGITS [layer_name ] = max (prev_logits_max , extra_info ["attn_logits" ].max ())
89+ prev_logits_max = ATTN_MAX_LOGITS [module . name ]
90+ ATTN_MAX_LOGITS [module . name ] = max (prev_logits_max , extra_info ["attn_logits" ].max ())
9191
9292 hook_handle : RemovableHandle = module .register_forward_hook (hook )
9393 self .hooks .append (hook_handle )
@@ -187,7 +187,7 @@ def get_metrics(self, data_batches: list[ModelItem]):
187187 def __enter__ (self ):
188188 for name , module in self .model .named_modules ():
189189 if isinstance (module , ATTENTION_CLS ):
190- self .register_attn_extra_info_hook (module , self . _clean_module_name ( name ) )
190+ self .register_attn_extra_info_hook (module )
191191 if isinstance (module , RMS_NORM_MONITOR_MODULES ):
192192 self .calculate_module_weight_rms (module , self ._clean_module_name (name ), dtype = torch .float32 )
193193 return self
0 commit comments