 from xtuner.v1.model import MoE
 from xtuner.v1.model.base import ModelItem
 from xtuner.v1.engine.train_engine import TrainEngine
+from xtuner.v1.utils.grad_norm import group_tensors_by_device_mesh_and_placements, cal_total_norm
+
+from typing_extensions import TypedDict
+
+
+class InternalMetrics(TypedDict):
+    weight_rms: dict[str, float]
+    maxvio: dict[str, float]
+    drop_ratio: dict[str, float]
+    router_logits_max: dict[str, float]
+    router_logits_mean: dict[str, float]
+    attn_max_lse: dict[str, float]
+    attn_max_logits: dict[str, float]
+

 RMS_NORM_MONITOR_MODULES = (
     nn.Embedding,
@@ -34,36 +48,32 @@ def __init__(self, engine: TrainEngine):
         self.intra_layer_micro_batch = engine.intra_layer_micro_batch
         self.hooks: list[RemovableHandle] = []
         # TODO: refactor with TypeDict
-        self.metrics: dict[str, dict[str, Any]] = dict[str, dict[str, Any]](
-            weight_rms=dict[str, Any](),
-            maxvio=dict[str, Any](),
-            drop_ratio=dict[str, Any](),
-            router_logits_max=dict[str, Any](),
-            router_logits_mean=dict[str, Any](),
-            attn_max_lse=dict[str, Any](),
-            attn_max_logits=dict[str, Any](),
-        )
-
-    def register_weight_rms_hook(self, module: nn.Module, layer_name: str):
-        """
-        Register weight RMS hook as a pre-forward hook, as at this point, the parameters are should be
-        all-gathered into current rank.
-        """
-        def hook(module, args, kwargs=None):
-            if layer_name in self.metrics['weight_rms']:  # only calculate before the first batch
-                return
-            l2_norm = 0.0
-            total_params = 0
-            for param in module.parameters():
-                if param.requires_grad:
-                    l2_norm += torch.norm(param.detach().float(), p=2) ** 2
-                    total_params += param.numel()
-            if total_params > 0:
-                rms = torch.sqrt(l2_norm / total_params)
-                self.metrics['weight_rms'][layer_name] = rms
-
-        hook_handle: RemovableHandle = module.register_forward_pre_hook(hook)
-        self.hooks.append(hook_handle)
+        self.metrics: InternalMetrics = {
+            "weight_rms": {},
+            "maxvio": {},
+            "drop_ratio": {},
+            "router_logits_max": {},
+            "router_logits_mean": {},
+            "attn_max_lse": {},
+            "attn_max_logits": {},
+        }
+        self.attn_max_lse: dict[str, torch.Tensor] = {}
+        self.attn_max_logits: dict[str, torch.Tensor] = {}
+
+    def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype: torch.dtype = torch.float32):
+        all_params = [param for param in module.parameters() if param.requires_grad]
+        if not all_params:
+            return
+        grouped_params = group_tensors_by_device_mesh_and_placements(all_params)
+        total_norms = []
+        total_numel = 0
+        for params in grouped_params.values():
+            total_norm = cal_total_norm(params, norm_type=2.0, foreach=True, dtype=dtype)
+            total_norms.append(total_norm)
+            total_numel += sum(p.numel() for p in params)
+        param_l2_norm = torch.linalg.vector_norm(torch.stack(total_norms), ord=2.0, dtype=dtype)
+        param_rms = param_l2_norm / total_numel ** 0.5
+        self.metrics['weight_rms'][layer_name] = param_rms.item()

     def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
         """
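
The new calculate_module_weight_rms reports the same quantity the removed pre-forward hook did, RMS(W) = ||W||_2 / sqrt(numel(W)) over a module's trainable parameters, but computes it eagerly and groups DTensor parameters by device mesh and placements so cal_total_norm can reduce shards correctly. As a minimal single-device sketch of the reported quantity, assuming plain local tensors and illustrative names only:

    import torch
    import torch.nn as nn

    def weight_rms(module: nn.Module) -> float:
        # RMS of all trainable weights: global L2 norm divided by sqrt of the element count
        params = [p.detach().float() for p in module.parameters() if p.requires_grad]
        if not params:
            return 0.0
        per_param = torch.stack([torch.linalg.vector_norm(p) for p in params])
        l2 = torch.linalg.vector_norm(per_param)
        numel = sum(p.numel() for p in params)
        return (l2 / numel ** 0.5).item()

    # e.g. weight_rms(nn.Embedding(1000, 64)) is close to 1.0 for standard-normal init
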
@@ -72,19 +82,19 @@ def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
         def hook(module, input, output):
             extra_info = output[1]
             if extra_info.get("softmax_lse", None) is not None:
-                if layer_name not in self.metrics["attn_max_lse"]:
+                if layer_name not in self.attn_max_lse:
                     # original shape: [n_head, seq]
-                    self.metrics["attn_max_lse"][layer_name] = extra_info["softmax_lse"].max()
+                    self.attn_max_lse[layer_name] = extra_info["softmax_lse"].max()
                 else:
-                    prev_lse_max = self.metrics["attn_max_lse"][layer_name]
-                    self.metrics["attn_max_lse"][layer_name] = max(prev_lse_max, extra_info["softmax_lse"].max())
+                    prev_lse_max = self.attn_max_lse[layer_name]
+                    self.attn_max_lse[layer_name] = max(prev_lse_max, extra_info["softmax_lse"].max())
             if extra_info.get("attn_logits", None) is not None:
-                if layer_name not in self.metrics["attn_max_logits"]:
+                if layer_name not in self.attn_max_logits:
                     # original shape: [b, n_head, seq, seq]
-                    self.metrics["attn_max_logits"][layer_name] = extra_info["attn_logits"].max()
+                    self.attn_max_logits[layer_name] = extra_info["attn_logits"].max()
                 else:
-                    prev_logits_max = self.metrics["attn_max_logits"][layer_name]
-                    self.metrics["attn_max_logits"][layer_name] = max(prev_logits_max, extra_info["attn_logits"].max())
+                    prev_logits_max = self.attn_max_logits[layer_name]
+                    self.attn_max_logits[layer_name] = max(prev_logits_max, extra_info["attn_logits"].max())

         hook_handle: RemovableHandle = module.register_forward_hook(hook)
         self.hooks.append(hook_handle)
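
The hook above keeps only a per-layer running maximum of the attention LSE/logits on the local rank; the full tensors are never stored. A stripped-down sketch of that hook pattern for a module whose forward returns (output, extra_info) — names here are illustrative, not the xtuner attention API:

    import torch
    import torch.nn as nn

    running_max: dict[str, torch.Tensor] = {}

    def make_hook(layer_name: str):
        def hook(module: nn.Module, inputs, output):
            # assume the module returns (hidden_states, extra_info) with a "softmax_lse" tensor
            extra_info = output[1]
            lse = extra_info.get("softmax_lse")
            if lse is None:
                return
            current = lse.detach().max()
            if layer_name in running_max:
                current = torch.maximum(running_max[layer_name], current)
            running_max[layer_name] = current
        return hook

    # handle = attn_module.register_forward_hook(make_hook("layers.0.attn")); later handle.remove()
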
@@ -150,8 +160,8 @@ def get_metrics(self, data_batches: list[ModelItem]):
                 {f"layer{idx}": maxvio_all_layers[idx].item() for idx in range(max_load_i.shape[0])}
             )
             maxvio = maxvio_all_layers.mean()
-            self.metrics["maxvio"]["total"] = maxvio
-            self.metrics["drop_ratio"]["total"] = drop_ratio
+            self.metrics["maxvio"]["total"] = maxvio.item()
+            self.metrics["drop_ratio"]["total"] = drop_ratio.item()

         if router_logits_max:
             for layer_name, router_logits_list in router_logits_max.items():
@@ -168,12 +178,12 @@ def get_metrics(self, data_batches: list[ModelItem]):
                 self.metrics["router_logits_mean"][layer_name] = local_router_logits_mean.item()

-        if self.metrics["attn_max_lse"]:
+        if self.attn_max_lse:
-            for layer_name, local_attn_max_lse in self.metrics["attn_max_lse"].items():
+            for layer_name, local_attn_max_lse in self.attn_max_lse.items():
                 dist.all_reduce(local_attn_max_lse, op=dist.ReduceOp.MAX)
                 self.metrics["attn_max_lse"][layer_name] = local_attn_max_lse.item()

-        if self.metrics["attn_max_logits"]:
-            for layer_name, local_attn_max_logits in self.metrics["attn_max_logits"].items():
+        if self.attn_max_logits:
+            for layer_name, local_attn_max_logits in self.attn_max_logits.items():
                 dist.all_reduce(local_attn_max_logits, op=dist.ReduceOp.MAX)
                 self.metrics["attn_max_logits"][layer_name] = local_attn_max_logits.item()

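
get_metrics then folds these local maxima across ranks with a MAX all-reduce before converting them to plain floats for logging. Roughly, assuming an initialized default process group:

    import torch
    import torch.distributed as dist

    def reduce_running_max(running_max: dict[str, torch.Tensor]) -> dict[str, float]:
        # global max over all ranks, then Python floats so the values are cheap to log/serialize
        out: dict[str, float] = {}
        for name, value in running_max.items():
            if dist.is_available() and dist.is_initialized():
                dist.all_reduce(value, op=dist.ReduceOp.MAX)
            out[name] = value.item()
        return out
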
@@ -184,7 +194,7 @@ def __enter__(self):
             if isinstance(module, ATTENTION_CLS):
                 self.register_attn_extra_info_hook(module, self._clean_module_name(name))
             if isinstance(module, RMS_NORM_MONITOR_MODULES):
-                self.register_weight_rms_hook(module, self._clean_module_name(name))
+                self.calculate_module_weight_rms(module, self._clean_module_name(name), dtype=torch.float32)

         return self

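
For orientation, __enter__ walks named_modules() once, registering attention hooks and computing weight RMS for the monitored module types, and the handles collected in self.hooks are presumably removed on exit. The monitor's class name and __exit__ are not visible in this diff, so the following is only a generic sketch of that enter/exit hook lifecycle with placeholder names:

    import torch.nn as nn
    from torch.utils.hooks import RemovableHandle

    class HookScope:
        """Illustrative stand-in for the monitor's context-manager lifecycle."""

        def __init__(self, model: nn.Module):
            self.model = model
            self.hooks: list[RemovableHandle] = []

        def __enter__(self):
            for name, module in self.model.named_modules():
                # stand-in for the ATTENTION_CLS / RMS_NORM_MONITOR_MODULES dispatch
                if isinstance(module, nn.Linear):
                    self.hooks.append(module.register_forward_hook(lambda m, i, o: None))
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # ensure monitoring hooks never leak past the monitored step
            for handle in self.hooks:
                handle.remove()
            self.hooks.clear()

    # with HookScope(nn.Sequential(nn.Linear(8, 8))) as scope: run the forward pass here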