Skip to content

Commit b3018e9

Browse files
committed
[Enhance] Support internal metrics for gdn A_log and norm
1 parent b0fdc8d commit b3018e9

1 file changed

Lines changed: 51 additions & 0 deletions

File tree

xtuner/v1/utils/internal_metrics.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from xtuner.v1.model.base import BaseModel as XTunerBaseModel
1414
from xtuner.v1.model.base import ModelItem
1515
from xtuner.v1.module import LMHead, MHAConfig, MLAConfig, MultiHeadAttention, MultiLatentAttention
16+
from xtuner.v1.module.attention.gated_deltanet import FusedRMSNormGated
1617
from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
1718
from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEDecoderLayer
1819
from xtuner.v1.utils.device import get_device
@@ -37,6 +38,8 @@
3738

3839
class InternalMetrics(TypedDict, total=False):
3940
weight_rms: dict[str, float]
41+
weight_min: dict[str, float]
42+
weight_max: dict[str, float]
4043
maxvio: dict[str, float]
4144
drop_ratio: dict[str, float]
4245
router_logits_max: dict[str, float]
@@ -90,6 +93,8 @@ def _init_metrics_dict(self) -> InternalMetrics:
9093

9194
if self.internal_metrics_cfg.monitor_weights_rms_norm:
9295
metrics["weight_rms"] = {}
96+
metrics["weight_min"] = {}
97+
metrics["weight_max"] = {}
9398

9499
if self.internal_metrics_cfg.monitor_attn_logits_stats:
95100
attn_cfg: MHAConfig | MLAConfig = self.model.config.attention # type: ignore[attr-defined]
@@ -153,6 +158,44 @@ def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype:
153158

154159
self.metrics["weight_rms"][layer_name] = param_rms.item()
155160

161+
@torch.no_grad()
162+
def calculate_module_weight_min_max(self, module: nn.Module | nn.Parameter | torch.Tensor, layer_name: str):
163+
"""Calculate the min and max of the module's parameters."""
164+
self._check_closed()
165+
166+
if "weight_min" not in self.metrics or "weight_max" not in self.metrics:
167+
return
168+
169+
if isinstance(module, nn.Module):
170+
all_params = [param.data for param in module.parameters() if param.requires_grad]
171+
else:
172+
all_params = [module.data]
173+
174+
if not all_params:
175+
return
176+
177+
# Handle DTensor - convert to local tensors
178+
from torch.distributed.tensor import DTensor
179+
180+
local_params = []
181+
for param in all_params:
182+
if isinstance(param, DTensor):
183+
local_params.append(param.to_local())
184+
else:
185+
local_params.append(param)
186+
187+
# Calculate local min/max
188+
local_min = torch.min(torch.stack([p.min() for p in local_params]))
189+
local_max = torch.max(torch.stack([p.max() for p in local_params]))
190+
191+
# All-reduce across ranks
192+
if dist.is_initialized() and dist.get_world_size() > 1:
193+
dist.all_reduce(local_min, op=dist.ReduceOp.MIN)
194+
dist.all_reduce(local_max, op=dist.ReduceOp.MAX)
195+
196+
self.metrics["weight_min"][layer_name] = local_min.item()
197+
self.metrics["weight_max"][layer_name] = local_max.item()
198+
156199
def register_attn_output_hook(self, module: nn.Module):
157200
"""Register attention output hook as a forward hook."""
158201
self._check_closed()
@@ -179,6 +222,14 @@ def pop_metrics(self, data_batches: list[ModelItem]):
179222
if self.internal_metrics_cfg.monitor_weights_rms_norm and isinstance(module, RMS_NORM_MONITOR_MODULES):
180223
self.calculate_module_weight_rms(module, self._clean_module_name(name), dtype=torch.float32)
181224

225+
if self.internal_metrics_cfg.monitor_weights_rms_norm and isinstance(module, FusedRMSNormGated):
226+
self.calculate_module_weight_min_max(module, self._clean_module_name(name))
227+
228+
if self.internal_metrics_cfg.monitor_weights_rms_norm and hasattr(module, "A_log"):
229+
self.calculate_module_weight_min_max(module.A_log, f"{self._clean_module_name(name)}.A_log")
230+
231+
232+
182233
additional_kwargs = {}
183234
if self.internal_metrics_cfg.monitor_moe_router_logits_stats and isinstance(self.model, MoE):
184235
# for MoE model, add additional kwargs to return necessary stats

0 commit comments

Comments
 (0)