Commit 30a35f1

[Enhance] Add type hint and fix typos and code styles
1 parent d538d8d commit 30a35f1

1 file changed

Lines changed: 16 additions & 14 deletions


xtuner/v1/utils/internal_metrics.py

@@ -4,6 +4,8 @@
 import torch
 from torch import nn
 import torch.distributed as dist
+from torch.utils.hooks import RemovableHandle
+
 from xtuner.v1.module import (
     RMSNorm,
     MultiHeadAttention,
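
Aside (not part of this diff): both `register_forward_pre_hook` and `register_forward_hook` return a `torch.utils.hooks.RemovableHandle`, which is why the annotation in the next hunk tightens `self.hooks` from `list[Any]`. A minimal sketch of the handle lifecycle, with an illustrative `nn.Linear` standing in for the monitored modules:

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle

layer = nn.Linear(4, 4)

def pre_hook(module, args):
    # Runs just before forward(); args is the tuple of positional inputs.
    print("pre-forward on", type(module).__name__)

handle: RemovableHandle = layer.register_forward_pre_hook(pre_hook)
layer(torch.randn(2, 4))  # pre_hook fires
handle.remove()           # detach the hook once monitoring is done
layer(torch.randn(2, 4))  # pre_hook no longer fires
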
@@ -35,7 +37,8 @@ class InternalMetricsRecorder:
     def __init__(self, engine: TrainEngine):
         self.model = engine.model
         self.intra_layer_micro_batch = engine.intra_layer_micro_batch
-        self.hooks: list[Any] = []
+        self.hooks: list[RemovableHandle] = []
+        # TODO: refactor with TypeDict
         self.metrics: dict[str, dict[str, Any]] = dict[str, dict[str, Any]](
             weight_rms=dict[str, Any](),
             maxvio=dict[str, Any](),
@@ -46,14 +49,11 @@ def __init__(self, engine: TrainEngine):
             attn_max_logits=dict[str, Any](),
         )
 
-    def register_weight_rms_hook(self, module, layer_name=None):
+    def register_weight_rms_hook(self, module: nn.Module, layer_name: str):
         """
         Register weight RMS hook as a pre-forward hook, as at this point, the parameters are should be
         all-gathered into current rank.
         """
-        if layer_name is None:
-            layer_name = f"layer_{len(self.weight_rms_dict)}"
-
         def hook(module, args, kwargs=None):
             if layer_name in self.metrics['weight_rms']:  # only calculate before the first batch
                 return
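
Aside (not part of this diff): the hook body in the next hunk ends with `rms = torch.sqrt(l2_norm / total_params)`. A standalone sketch of that metric; the parameter loop here is an assumption about how `l2_norm` and `total_params` are accumulated:

import torch
from torch import nn

@torch.no_grad()
def weight_rms(module: nn.Module) -> torch.Tensor:
    # RMS of all parameter values: sqrt(sum of squares / number of elements).
    l2_norm = torch.zeros((), dtype=torch.float32)
    total_params = 0
    for param in module.parameters():
        l2_norm += param.detach().float().pow(2).sum().cpu()
        total_params += param.numel()
    return torch.sqrt(l2_norm / total_params)

print(weight_rms(nn.Linear(8, 8)))
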
@@ -67,10 +67,13 @@ def hook(module, args, kwargs=None):
             rms = torch.sqrt(l2_norm / total_params)
             self.metrics['weight_rms'][layer_name] = rms
 
-        hook_handle = module.register_forward_pre_hook(hook)
+        hook_handle: RemovableHandle = module.register_forward_pre_hook(hook)
         self.hooks.append(hook_handle)
 
-    def register_attn_extra_info_hook(self, module, layer_name=None):
+    def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
+        """
+        Register attention extra info hook as a forward hook
+        """
         def hook(module, input, output):
             extra_info = output[1]
             if extra_info.get("softmax_lse", None) is not None:
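
Aside (not part of this diff): the forward hook above assumes the attention module returns an `(output, extra_info)` tuple whose second element carries keys such as `softmax_lse` and `attn_logits`. A toy sketch of that pattern, with `ToyAttention` standing in for the real `MultiHeadAttention`:

import torch
from torch import nn

class ToyAttention(nn.Module):
    # Stand-in: returns (output, extra_info), mimicking the assumed interface.
    def forward(self, x):
        return x, {"softmax_lse": x.logsumexp(dim=-1)}

max_lse = {}

def hook(module, input, output):
    extra_info = output[1]
    if extra_info.get("softmax_lse", None) is not None:
        lse = extra_info["softmax_lse"].max()
        max_lse["toy_attn"] = torch.maximum(max_lse.get("toy_attn", lse), lse)

attn = ToyAttention()
attn.register_forward_hook(hook)
attn(torch.randn(2, 8))
print(max_lse)
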
@@ -88,7 +91,7 @@ def hook(module, input, output):
                 prev_logits_max = self.metrics["attn_max_logits"][layer_name]
                 self.metrics["attn_max_logits"][layer_name] = max(prev_logits_max, extra_info["attn_logits"].max())
 
-        hook_handle = module.register_forward_hook(hook)
+        hook_handle: RemovableHandle = module.register_forward_hook(hook)
         self.hooks.append(hook_handle)
 
     @torch.no_grad()
@@ -155,26 +158,26 @@ def get_metrics(self, data_batches: list[ModelItem]):
         self.metrics["maxvio"]["total"] = maxvio
         self.metrics["drop_ratio"]["total"] = drop_ratio
 
-        if len(router_logits_max) > 0:
+        if router_logits_max:
             for layer_name, router_logits_list in router_logits_max.items():
                 # [bsz/intra_layer_micro_batch, ]
                 local_router_logits_max = torch.max(torch.stack(router_logits_list))
                 dist.all_reduce(local_router_logits_max, op=dist.ReduceOp.MAX)
                 self.metrics["router_logits_max"][layer_name] = local_router_logits_max.item()
 
-        if len(router_logits_mean) > 0:
+        if router_logits_mean:
             for layer_name, router_logits_list in router_logits_mean.items():
                 # [bsz/intra_layer_micro_batch, ]
                 local_router_logits_mean = torch.mean(torch.stack(router_logits_list))
                 dist.all_reduce(local_router_logits_mean.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
                 self.metrics["router_logits_mean"][layer_name] = local_router_logits_mean.item()
 
-        if len(self.metrics["attn_max_lse"]) > 0:
+        if self.metrics["attn_max_lse"]:
             for layer_name, local_attn_max_lse in self.metrics["attn_max_lse"].items():
                 dist.all_reduce(local_attn_max_lse, op=dist.ReduceOp.MAX)
                 self.metrics["attn_max_lse"][layer_name] = local_attn_max_lse.item()
 
-        if len(self.metrics["attn_max_logits"]) > 0:
+        if self.metrics["attn_max_logits"]:
             for layer_name, local_attn_max_logits in self.metrics["attn_max_logits"].items():
                 dist.all_reduce(local_attn_max_logits, op=dist.ReduceOp.MAX)
                 self.metrics["attn_max_logits"][layer_name] = local_attn_max_logits.item()
@@ -187,8 +190,7 @@ def __enter__(self):
                 self.register_attn_extra_info_hook(module, self._clean_module_name(name))
             if isinstance(module, RMS_NORM_MONITOR_MODULES):
                 self.register_weight_rms_hook(module, self._clean_module_name(name))
-            else:
-                pass
+
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
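
Aside (not part of this diff): `__enter__` registers the monitoring hooks, and a matching `__exit__` would normally release them through the stored `RemovableHandle`s. A self-contained sketch of that pattern; `HookScope` is an illustrative stand-in, not the real `InternalMetricsRecorder`:

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle

class HookScope:
    # Simplified stand-in for the recorder's enter/exit behaviour (assumption:
    # the real __exit__ removes the handles collected in self.hooks).
    def __init__(self, model: nn.Module):
        self.model = model
        self.hooks: list[RemovableHandle] = []

    def __enter__(self):
        for _, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                self.hooks.append(module.register_forward_pre_hook(lambda m, a: None))
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()

with HookScope(nn.Sequential(nn.Linear(4, 4), nn.ReLU())) as scope:
    print(len(scope.hooks), "hooks registered")
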

0 commit comments
