 import torch
 from torch import nn
 import torch.distributed as dist
+from torch.utils.hooks import RemovableHandle
+
 from xtuner.v1.module import (
     RMSNorm,
     MultiHeadAttention,
@@ -35,7 +37,8 @@ class InternalMetricsRecorder:
     def __init__(self, engine: TrainEngine):
         self.model = engine.model
         self.intra_layer_micro_batch = engine.intra_layer_micro_batch
-        self.hooks: list[Any] = []
+        self.hooks: list[RemovableHandle] = []
+        # TODO: refactor with TypedDict
         self.metrics: dict[str, dict[str, Any]] = dict[str, dict[str, Any]](
             weight_rms=dict[str, Any](),
             maxvio=dict[str, Any](),
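
The `TODO` above points at replacing these untyped nested dicts with a typed container. A minimal sketch, assuming a `TypedDict` keyed by the metric names visible in this diff (the class name `InternalMetrics` is illustrative, not part of the change):

```python
from typing import Any, TypedDict


class InternalMetrics(TypedDict):
    # each entry maps a layer name to the recorded value for that metric
    weight_rms: dict[str, Any]
    maxvio: dict[str, Any]
    drop_ratio: dict[str, Any]
    router_logits_max: dict[str, Any]
    router_logits_mean: dict[str, Any]
    attn_max_lse: dict[str, Any]
    attn_max_logits: dict[str, Any]
```
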
@@ -46,14 +49,11 @@ def __init__(self, engine: TrainEngine):
             attn_max_logits=dict[str, Any](),
         )

-    def register_weight_rms_hook(self, module, layer_name=None):
+    def register_weight_rms_hook(self, module: nn.Module, layer_name: str):
5053 """
5154 Register weight RMS hook as a pre-forward hook, as at this point, the parameters are should be
5255 all-gathered into current rank.
5356 """
-        if layer_name is None:
-            layer_name = f"layer_{len(self.weight_rms_dict)}"
-
         def hook(module, args, kwargs=None):
             if layer_name in self.metrics['weight_rms']:  # only calculate before the first batch
                 return
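
For reference, the value this hook records is the root-mean-square over all parameter elements of the module, which is meaningful at the pre-forward point because (as the docstring notes) the full parameters are materialized on the current rank. A minimal standalone sketch of the same computation (names are illustrative):

```python
import torch
from torch import nn


@torch.no_grad()
def weight_rms(module: nn.Module) -> torch.Tensor:
    # sum of squared elements across all parameters, divided by the total element count
    l2_norm = sum(p.detach().float().pow(2).sum() for p in module.parameters())
    total_params = sum(p.numel() for p in module.parameters())
    return torch.sqrt(l2_norm / total_params)
```
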
@@ -67,10 +67,13 @@ def hook(module, args, kwargs=None):
             rms = torch.sqrt(l2_norm / total_params)
             self.metrics['weight_rms'][layer_name] = rms

-        hook_handle = module.register_forward_pre_hook(hook)
+        hook_handle: RemovableHandle = module.register_forward_pre_hook(hook)
         self.hooks.append(hook_handle)

-    def register_attn_extra_info_hook(self, module, layer_name=None):
+    def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
+        """
+        Register attention extra info hook as a forward hook.
+        """
         def hook(module, input, output):
             extra_info = output[1]
             if extra_info.get("softmax_lse", None) is not None:
@@ -88,7 +91,7 @@ def hook(module, input, output):
                 prev_logits_max = self.metrics["attn_max_logits"][layer_name]
                 self.metrics["attn_max_logits"][layer_name] = max(prev_logits_max, extra_info["attn_logits"].max())

-        hook_handle = module.register_forward_hook(hook)
+        hook_handle: RemovableHandle = module.register_forward_hook(hook)
         self.hooks.append(hook_handle)

     @torch.no_grad()
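
A note on the new `RemovableHandle` annotation: both `register_forward_pre_hook` and `register_forward_hook` return a handle whose `remove()` method detaches the hook, which is presumably how the collected `self.hooks` are cleaned up in `__exit__`. A minimal sketch of that lifecycle on a plain module:

```python
import torch
from torch import nn
from torch.utils.hooks import RemovableHandle


def pre_hook(module, args):
    print("pre-forward fired")


def post_hook(module, args, output):
    print("forward hook fired")


layer = nn.Linear(4, 4)
handles: list[RemovableHandle] = [
    layer.register_forward_pre_hook(pre_hook),
    layer.register_forward_hook(post_hook),
]
layer(torch.randn(2, 4))   # both hooks fire
for handle in handles:
    handle.remove()        # detach the hooks once recording is done
layer(torch.randn(2, 4))   # no hooks fire
```
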
@@ -155,26 +158,26 @@ def get_metrics(self, data_batches: list[ModelItem]):
         self.metrics["maxvio"]["total"] = maxvio
         self.metrics["drop_ratio"]["total"] = drop_ratio

-        if len(router_logits_max) > 0:
+        if router_logits_max:
             for layer_name, router_logits_list in router_logits_max.items():
                 # [bsz/intra_layer_micro_batch, ]
                 local_router_logits_max = torch.max(torch.stack(router_logits_list))
                 dist.all_reduce(local_router_logits_max, op=dist.ReduceOp.MAX)
                 self.metrics["router_logits_max"][layer_name] = local_router_logits_max.item()

-        if len(router_logits_mean) > 0:
+        if router_logits_mean:
             for layer_name, router_logits_list in router_logits_mean.items():
                 # [bsz/intra_layer_micro_batch, ]
                 local_router_logits_mean = torch.mean(torch.stack(router_logits_list))
                 dist.all_reduce(local_router_logits_mean.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
                 self.metrics["router_logits_mean"][layer_name] = local_router_logits_mean.item()

-        if len(self.metrics["attn_max_lse"]) > 0:
+        if self.metrics["attn_max_lse"]:
             for layer_name, local_attn_max_lse in self.metrics["attn_max_lse"].items():
                 dist.all_reduce(local_attn_max_lse, op=dist.ReduceOp.MAX)
                 self.metrics["attn_max_lse"][layer_name] = local_attn_max_lse.item()

-        if len(self.metrics["attn_max_logits"]) > 0:
+        if self.metrics["attn_max_logits"]:
             for layer_name, local_attn_max_logits in self.metrics["attn_max_logits"].items():
                 dist.all_reduce(local_attn_max_logits, op=dist.ReduceOp.MAX)
                 self.metrics["attn_max_logits"][layer_name] = local_attn_max_logits.item()
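
The aggregation pattern above is `ReduceOp.MAX` for extrema and, for averages, a local mean pre-divided by the world size followed by `ReduceOp.SUM`, which yields the mean of the per-rank means. A minimal single-process sketch of the same idiom (the gloo backend and rendezvous address are assumptions for the demo):

```python
import torch
import torch.distributed as dist

# world_size=1 so this runs as an ordinary script; real training uses the launcher's settings
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", world_size=1, rank=0)

local_max = torch.tensor(3.5)
dist.all_reduce(local_max, op=dist.ReduceOp.MAX)   # global maximum across ranks

local_mean = torch.tensor(1.2)
local_mean.div_(dist.get_world_size())             # pre-scale so the SUM below averages per-rank means
dist.all_reduce(local_mean, op=dist.ReduceOp.SUM)

dist.destroy_process_group()
```
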
@@ -187,8 +190,7 @@ def __enter__(self):
                 self.register_attn_extra_info_hook(module, self._clean_module_name(name))
             if isinstance(module, RMS_NORM_MONITOR_MODULES):
                 self.register_weight_rms_hook(module, self._clean_module_name(name))
-            else:
-                pass
+
         return self

     def __exit__(self, exc_type, exc_value, traceback):
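
Since `__enter__` registers the hooks (and `__exit__`, not shown in full here, is where they would be removed), the recorder reads naturally as a context manager. A hypothetical usage sketch, assuming an existing `engine` and `data_batches` from the surrounding trainer code:

```python
with InternalMetricsRecorder(engine) as recorder:
    recorder.get_metrics(data_batches)       # forward passes run with the hooks attached
print(recorder.metrics["weight_rms"])        # per-layer weight RMS collected during the pass
```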