@@ -13,6 +13,7 @@
 from xtuner.v1.model.base import BaseModel as XTunerBaseModel
 from xtuner.v1.model.base import ModelItem
 from xtuner.v1.module import LMHead, MHAConfig, MLAConfig, MultiHeadAttention, MultiLatentAttention
+from xtuner.v1.module.attention.gated_deltanet import FusedRMSNormGated
 from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
 from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEDecoderLayer
 from xtuner.v1.utils.device import get_device
@@ -37,6 +38,8 @@

 class InternalMetrics(TypedDict, total=False):
     weight_rms: dict[str, float]
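+    # global (cross-rank) extrema of monitored weights, keyed by module name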
+    weight_min: dict[str, float]
+    weight_max: dict[str, float]
     maxvio: dict[str, float]
     drop_ratio: dict[str, float]
     router_logits_max: dict[str, float]
@@ -90,6 +93,8 @@ def _init_metrics_dict(self) -> InternalMetrics:

         if self.internal_metrics_cfg.monitor_weights_rms_norm:
             metrics["weight_rms"] = {}
+            metrics["weight_min"] = {}
+            metrics["weight_max"] = {}

         if self.internal_metrics_cfg.monitor_attn_logits_stats:
             attn_cfg: MHAConfig | MLAConfig = self.model.config.attention  # type: ignore[attr-defined]
@@ -153,6 +158,44 @@ def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype:

         self.metrics["weight_rms"][layer_name] = param_rms.item()

+    @torch.no_grad()
+    def calculate_module_weight_min_max(self, module: nn.Module | nn.Parameter | torch.Tensor, layer_name: str):
+        """Record the global min and max of the module's (or parameter's) values."""
+        self._check_closed()
+
+        if "weight_min" not in self.metrics or "weight_max" not in self.metrics:
+            return
+
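+        # Accepts either a whole nn.Module (aggregated over its trainable
+        # parameters) or a single nn.Parameter / tensor (e.g. a raw A_log).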
+        if isinstance(module, nn.Module):
+            all_params = [param.data for param in module.parameters() if param.requires_grad]
+        else:
+            all_params = [module.data]
+
+        if not all_params:
+            return
+
+        # Handle DTensor - convert to local tensors
+        from torch.distributed.tensor import DTensor
+
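+        # Per-shard extrema followed by a MIN/MAX all-reduce match the extrema
+        # of the full tensor, for both sharded and replicated placements.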
+        local_params = []
+        for param in all_params:
+            if isinstance(param, DTensor):
+                local_params.append(param.to_local())
+            else:
+                local_params.append(param)
+
+        # Calculate local min/max
+        local_min = torch.min(torch.stack([p.min() for p in local_params]))
+        local_max = torch.max(torch.stack([p.max() for p in local_params]))
+
+        # All-reduce across ranks
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            dist.all_reduce(local_min, op=dist.ReduceOp.MIN)
+            dist.all_reduce(local_max, op=dist.ReduceOp.MAX)
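+        # After the reduction, local_min/local_max hold the global extrema on every rank.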
+
+        self.metrics["weight_min"][layer_name] = local_min.item()
+        self.metrics["weight_max"][layer_name] = local_max.item()
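+        # e.g. self.metrics["weight_min"]["layers.0.norm"] -> float (key name illustrative)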
+
     def register_attn_output_hook(self, module: nn.Module):
         """Register attention output hook as a forward hook."""
         self._check_closed()
@@ -179,6 +222,14 @@ def pop_metrics(self, data_batches: list[ModelItem]):
             if self.internal_metrics_cfg.monitor_weights_rms_norm and isinstance(module, RMS_NORM_MONITOR_MODULES):
                 self.calculate_module_weight_rms(module, self._clean_module_name(name), dtype=torch.float32)

+            if self.internal_metrics_cfg.monitor_weights_rms_norm and isinstance(module, FusedRMSNormGated):
+                self.calculate_module_weight_min_max(module, self._clean_module_name(name))
+
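+            # A_log is typically the learnable log-decay parameter of linear-attention
+            # layers (e.g. GatedDeltaNet); min/max surfaces outliers an RMS average masks.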
+            if self.internal_metrics_cfg.monitor_weights_rms_norm and hasattr(module, "A_log"):
+                self.calculate_module_weight_min_max(module.A_log, f"{self._clean_module_name(name)}.A_log")
+
         additional_kwargs = {}
         if self.internal_metrics_cfg.monitor_moe_router_logits_stats and isinstance(self.model, MoE):
             # for MoE model, add additional kwargs to return necessary stats