
Commit 14aa604

[Enhance] drop rms_norm hook impl. to avoid precision problems (fp8)
[Enhance] refactor internal metrics to use TypedDict
1 parent ebf8d26 commit 14aa604

4 files changed

Lines changed: 69 additions & 108 deletions

File tree

- xtuner/v1/float8/float8_ops.py
- xtuner/v1/train/trainer.py
- xtuner/v1/utils/__init__.py
- xtuner/v1/utils/internal_metrics.py

xtuner/v1/float8/float8_ops.py

Lines changed: 1 addition & 53 deletions
```diff
@@ -165,56 +165,4 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None):
         args[0]._orig_dtype,
         args[0]._scaling_granularity,
         args[0]._group_size,
-    )
-
-
-@implements(
-    [
-        aten._to_copy.default,
-    ]
-)
-def float8_to_copy(aten_op, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
-
-    dtype = kwargs.get("dtype", torch.float32)
-
-    tensor = args[0]
-    scaling_granularity = tensor._scaling_granularity
-
-    dequantized: torch.Tensor = None
-
-    if scaling_granularity == ScalingGranularity.TENSORWISE:
-        dequantized = tensor._data.to(torch.float32) * tensor._scale.to(torch.float32)
-
-    elif scaling_granularity == ScalingGranularity.BLOCKWISE:
-        from xtuner.v1.float8.triton_kernels import per_block_dequant_gemm
-
-        if tensor._data.ndim == 2:
-            dequantized = per_block_dequant_gemm(tensor._data, tensor._scale, block_size=tensor._group_size)
-
-        else:
-            raise NotImplementedError(
-                f"{aten_op} with {scaling_granularity} scaling granularity is not implemented. "
-            )
-    elif scaling_granularity == ScalingGranularity.TILEWISE:
-        # For tilewise, scale is per-tile (1x128)
-        original_shape = tensor._data.shape
-        data_flat = tensor._data.view(-1, original_shape[-1])
-        scale_flat = tensor._scale.view(-1, tensor._scale.shape[-1])
-
-        # Expand scale to match data: each scale applies to group_size elements
-        group_size = tensor._group_size
-        num_groups = data_flat.shape[-1] // group_size
-        scale_expanded = scale_flat[:, :num_groups].unsqueeze(-1).expand(-1, -1, group_size)
-        scale_expanded = scale_expanded.contiguous().view(-1, data_flat.shape[-1])
-
-        # Dequantize: data * scale
-        dequantized_flat = data_flat.to(torch.float32) * scale_expanded.to(torch.float32)
-        dequantized = dequantized_flat.view(*original_shape)
-    else:
-        raise NotImplementedError(
-            f"{aten_op} with {scaling_granularity} scaling granularity is not supported. "
-        )
-
-    return dequantized.to(dtype)
+    )
```
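The removed op implemented `aten._to_copy` for float8 tensors by dequantizing in fp32 according to the scaling granularity. For reference, the TILEWISE branch amounted to broadcasting each 1x128 tile scale over its tile before multiplying; below is a minimal standalone sketch of that arithmetic on plain tensors (argument names and shapes are illustrative only, not the `Float8Tensor` API):

```python
import torch

def dequant_tilewise(data: torch.Tensor, scale: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    """Sketch of the removed TILEWISE dequantization: one scale per 1x128 tile."""
    original_shape = data.shape
    data_flat = data.view(-1, original_shape[-1]).to(torch.float32)
    scale_flat = scale.view(-1, scale.shape[-1]).to(torch.float32)

    # Each scale entry covers group_size consecutive elements of the last dim.
    num_groups = data_flat.shape[-1] // group_size
    scale_expanded = scale_flat[:, :num_groups].unsqueeze(-1).expand(-1, -1, group_size)
    scale_expanded = scale_expanded.contiguous().view(-1, data_flat.shape[-1])

    return (data_flat * scale_expanded).view(*original_shape)

# Example: 4 rows of 256 elements, one scale per 128-element tile.
data = torch.randn(4, 256)
scale = torch.rand(4, 2) + 0.5
out = dequant_tilewise(data, scale)
assert out.shape == (4, 256)
```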

xtuner/v1/train/trainer.py

Lines changed: 12 additions & 10 deletions
```diff
@@ -47,6 +47,7 @@
     log_format,
     record_git_info,
     InternalMetricsRecorder,
+    InternalMetrics,
 )
 from xtuner.v1.utils.device import get_device, get_torch_device_module
 
@@ -609,7 +610,7 @@ def fit(self):
         self._exp_tracker.close()
         self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")
 
-    def _maybe_check_model_internal_metrics(self, data_batches: list[ModelItem]) -> dict[str, float] | None:
+    def _maybe_check_model_internal_metrics(self, data_batches: list[ModelItem]) -> InternalMetrics | None:
         if self._internal_metrics_interval is None:
             return None
 
@@ -1183,7 +1184,7 @@ def _log_step(
         if internal_metrics is None:
             internal_metrics = {}
         else:
-            internal_metrics = _flatten_dict(internal_metrics)
+            internal_metrics = _flatten_nested_metrics(internal_metrics)
 
         self.logger.info(
             f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
@@ -1437,14 +1438,15 @@ def _setup_env(self):
     logger.info(log_str)
 
 
-def _flatten_dict(d: dict, parent_key: str = '', sep: str = '/') -> dict:
+def _flatten_nested_metrics(metrics: InternalMetrics, sep: str = '/') -> dict:
     items = []
-    for k, v in d.items():
-        new_key = f"{parent_key}{sep}{k}" if parent_key else k
-        if isinstance(v, dict):
-            items.extend(_flatten_dict(v, new_key, sep=sep).items())
-        elif isinstance(v, torch.Tensor):
-            items.append((new_key, v.item()))
+    for name, sub_metrics in metrics.items():
+        if isinstance(sub_metrics, dict):
+            for k, v in sub_metrics.items():
+                if isinstance(v, (float, int)):
+                    items.append((f"{name}{sep}{k}", v))
+                else:
+                    raise ValueError(f"Unsupported metric value type: expected float or int, but got {type(v)}")
         else:
-            items.append((new_key, v))
+            raise ValueError(f"Unsupported metric type for internal metrics: expected dict, but got {type(sub_metrics)}")
     return dict(items)
```
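The new helper flattens exactly one level of nesting (`group/key`) and raises on anything else, where the old `_flatten_dict` recursed to arbitrary depth and silently called `.item()` on tensors. A minimal sketch of the new contract, with hypothetical metric values:

```python
# Hypothetical InternalMetrics-shaped input; values must already be Python numbers.
metrics = {
    "weight_rms": {"layer0": 0.021, "layer1": 0.019},
    "maxvio": {"total": 0.35},
}

# One level of flattening with "/" as the separator, as in _flatten_nested_metrics.
flat = {f"{name}/{k}": v for name, sub in metrics.items() for k, v in sub.items()}
print(flat)  # {'weight_rms/layer0': 0.021, 'weight_rms/layer1': 0.019, 'maxvio/total': 0.35}
```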

xtuner/v1/utils/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -14,7 +14,7 @@
 from .type_helper import copy_method_signature, copy_signature
 from .update_weights_utils import monkey_unpatch_torch_reductions
 
-from .internal_metrics import InternalMetricsRecorder
+from .internal_metrics import InternalMetricsRecorder, InternalMetrics
 
 IGNORE_INDEX = -100
 
@@ -47,4 +47,5 @@
     "IGNORE_INDEX",
     "monkey_unpatch_torch_reductions",
     "InternalMetricsRecorder",
+    "InternalMetrics",
 ]
```

xtuner/v1/utils/internal_metrics.py

Lines changed: 54 additions & 44 deletions
```diff
@@ -17,6 +17,20 @@
 from xtuner.v1.model import MoE
 from xtuner.v1.model.base import ModelItem
 from xtuner.v1.engine.train_engine import TrainEngine
+from xtuner.v1.utils.grad_norm import group_tensors_by_device_mesh_and_placements, cal_total_norm
+
+from typing_extensions import TypedDict
+
+
+class InternalMetrics(TypedDict):
+    weight_rms: dict[str, float]
+    maxvio: dict[str, float]
+    drop_ratio: dict[str, float]
+    router_logits_max: dict[str, float]
+    router_logits_mean: dict[str, float]
+    attn_max_lse: dict[str, float]
+    attn_max_logits: dict[str, float]
+
 
 RMS_NORM_MONITOR_MODULES = (
     nn.Embedding,
```
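`TypedDict` keeps the container a plain `dict` at runtime while giving type checkers a fixed key set and per-key value types. A small sketch of what that buys, abbreviated to two keys (class and variable names here are illustrative):

```python
from typing_extensions import TypedDict

class Metrics(TypedDict):
    weight_rms: dict[str, float]
    maxvio: dict[str, float]

m: Metrics = {"weight_rms": {}, "maxvio": {}}
m["weight_rms"]["layer0"] = 0.02   # OK: values are plain floats
# m["drop_ratio"] = {}             # mypy/pyright error: "drop_ratio" is not a declared key
```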
```diff
@@ -34,36 +48,32 @@ def __init__(self, engine: TrainEngine):
         self.intra_layer_micro_batch = engine.intra_layer_micro_batch
         self.hooks: list[RemovableHandle] = []
         # TODO: refactor with TypeDict
-        self.metrics: dict[str, dict[str, Any]] = dict[str, dict[str, Any]](
-            weight_rms=dict[str, Any](),
-            maxvio=dict[str, Any](),
-            drop_ratio=dict[str, Any](),
-            router_logits_max=dict[str, Any](),
-            router_logits_mean=dict[str, Any](),
-            attn_max_lse=dict[str, Any](),
-            attn_max_logits=dict[str, Any](),
-        )
-
-    def register_weight_rms_hook(self, module: nn.Module, layer_name: str):
-        """
-        Register weight RMS hook as a pre-forward hook, as at this point, the parameters are should be
-        all-gathered into current rank.
-        """
-        def hook(module, args, kwargs=None):
-            if layer_name in self.metrics['weight_rms']:  # only calculate before the first batch
-                return
-            l2_norm = 0.0
-            total_params = 0
-            for param in module.parameters():
-                if param.requires_grad:
-                    l2_norm += torch.norm(param.detach().float(), p=2) ** 2
-                    total_params += param.numel()
-            if total_params > 0:
-                rms = torch.sqrt(l2_norm / total_params)
-                self.metrics['weight_rms'][layer_name] = rms
-
-        hook_handle: RemovableHandle = module.register_forward_pre_hook(hook)
-        self.hooks.append(hook_handle)
+        self.metrics: InternalMetrics = {
+            "weight_rms": {},
+            "maxvio": {},
+            "drop_ratio": {},
+            "router_logits_max": {},
+            "router_logits_mean": {},
+            "attn_max_lse": {},
+            "attn_max_logits": {},
+        }
+        self.attn_max_lse: dict[str, torch.Tensor] = {}
+        self.attn_max_logits: dict[str, torch.Tensor] = {}
+
+    def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype: torch.dtype = torch.float32):
+        all_params = [param for param in module.parameters() if param.requires_grad]
+        if not all_params:
+            return
+        grouped_params = group_tensors_by_device_mesh_and_placements(all_params)
+        total_norms = []
+        total_numel = 0
+        for params in grouped_params.values():
+            total_norm = cal_total_norm(params, norm_type=2.0, foreach=True, dtype=dtype)
+            total_norms.append(total_norm)
+            total_numel += sum(p.numel() for p in params)
+        param_l2_norm = torch.linalg.vector_norm(torch.stack(total_norms), ord=2.0, dtype=dtype)
+        param_rms = param_l2_norm / total_numel**0.5
+        self.metrics['weight_rms'][layer_name] = param_rms.item()
 
     def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
         """
```
```diff
@@ -72,19 +82,19 @@ def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
         def hook(module, input, output):
             extra_info = output[1]
             if extra_info.get("softmax_lse", None) is not None:
-                if layer_name not in self.metrics["attn_max_lse"]:
+                if layer_name not in self.attn_max_lse:
                     # original shape: [n_head, seq]
-                    self.metrics["attn_max_lse"][layer_name] = extra_info["softmax_lse"].max()
+                    self.attn_max_lse[layer_name] = extra_info["softmax_lse"].max()
                 else:
-                    prev_lse_max = self.metrics["attn_max_lse"][layer_name]
-                    self.metrics["attn_max_lse"][layer_name] = max(prev_lse_max, extra_info["softmax_lse"].max())
+                    prev_lse_max = self.attn_max_lse[layer_name]
+                    self.attn_max_lse[layer_name] = max(prev_lse_max, extra_info["softmax_lse"].max())
             if extra_info.get("attn_logits", None) is not None:
-                if layer_name not in self.metrics["attn_max_logits"]:
+                if layer_name not in self.attn_max_logits:
                     # original shape: [b, n_head, seq, seq]
-                    self.metrics["attn_max_logits"][layer_name] = extra_info["attn_logits"].max()
+                    self.attn_max_logits[layer_name] = extra_info["attn_logits"].max()
                 else:
-                    prev_logits_max = self.metrics["attn_max_logits"][layer_name]
-                    self.metrics["attn_max_logits"][layer_name] = max(prev_logits_max, extra_info["attn_logits"].max())
+                    prev_logits_max = self.attn_max_logits[layer_name]
+                    self.attn_max_logits[layer_name] = max(prev_logits_max, extra_info["attn_logits"].max())
 
         hook_handle: RemovableHandle = module.register_forward_hook(hook)
         self.hooks.append(hook_handle)
```
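Moving the running maxima into separate `attn_max_lse` / `attn_max_logits` dicts lets the hooks accumulate raw tensors across micro-batches while `self.metrics` only ever holds `float`s. The accumulation pattern, reduced to a toy module (module and key names illustrative):

```python
import torch
import torch.nn as nn

stats: dict[str, torch.Tensor] = {}

def make_hook(name: str):
    def hook(module, inputs, output):
        # Keep a running max as a tensor; convert to float only at collection time.
        val = output.detach().max()
        stats[name] = val if name not in stats else torch.maximum(stats[name], val)
    return hook

layer = nn.Linear(8, 8)
handle = layer.register_forward_hook(make_hook("layer0"))
layer(torch.randn(4, 8))
layer(torch.randn(4, 8))
handle.remove()
print(stats["layer0"])  # still a tensor, ready for dist.all_reduce(..., MAX)
```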
```diff
@@ -150,8 +160,8 @@ def get_metrics(self, data_batches: list[ModelItem]):
                 {f"layer{idx}": maxvio_all_layers[idx].item() for idx in range(max_load_i.shape[0])}
             )
             maxvio = maxvio_all_layers.mean()
-            self.metrics["maxvio"]["total"] = maxvio
-            self.metrics["drop_ratio"]["total"] = drop_ratio
+            self.metrics["maxvio"]["total"] = maxvio.item()
+            self.metrics["drop_ratio"]["total"] = drop_ratio.item()
 
         if router_logits_max:
             for layer_name, router_logits_list in router_logits_max.items():
```
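The added `.item()` calls keep `self.metrics` consistent with the `InternalMetrics` declaration, whose values are `dict[str, float]`: what gets stored is a plain Python float rather than a 0-dim tensor. For example:

```python
import torch

maxvio = torch.tensor(0.35)
metrics: dict[str, float] = {}
metrics["total"] = maxvio.item()   # plain float, safe to log or serialize
assert isinstance(metrics["total"], float)
```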
```diff
@@ -168,12 +178,12 @@ def get_metrics(self, data_batches: list[ModelItem]):
                 self.metrics["router_logits_mean"][layer_name] = local_router_logits_mean.item()
 
         if self.metrics["attn_max_lse"]:
-            for layer_name, local_attn_max_lse in self.metrics["attn_max_lse"].items():
+            for layer_name, local_attn_max_lse in self.attn_max_lse.items():
                 dist.all_reduce(local_attn_max_lse, op=dist.ReduceOp.MAX)
                 self.metrics["attn_max_lse"][layer_name] = local_attn_max_lse.item()
 
-        if self.metrics["attn_max_logits"]:
-            for layer_name, local_attn_max_logits in self.metrics["attn_max_logits"].items():
+        if self.attn_max_logits:
+            for layer_name, local_attn_max_logits in self.attn_max_logits.items():
                 dist.all_reduce(local_attn_max_logits, op=dist.ReduceOp.MAX)
                 self.metrics["attn_max_logits"][layer_name] = local_attn_max_logits.item()
```
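At collection time each rank's running max is all-reduced with `ReduceOp.MAX` and only then converted with `.item()`. A single-process demonstration of that reduction (gloo backend, world size 1, values illustrative):

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

local_attn_max = torch.tensor(3.7)
dist.all_reduce(local_attn_max, op=dist.ReduceOp.MAX)  # in-place: tensor now holds the global max
print(local_attn_max.item())  # 3.7 — with one rank, the local max is already global

dist.destroy_process_group()
```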

```diff
@@ -184,7 +194,7 @@ def __enter__(self):
             if isinstance(module, ATTENTION_CLS):
                 self.register_attn_extra_info_hook(module, self._clean_module_name(name))
             if isinstance(module, RMS_NORM_MONITOR_MODULES):
-                self.register_weight_rms_hook(module, self._clean_module_name(name))
+                self.calculate_module_weight_rms(module, self._clean_module_name(name), dtype=torch.float32)
 
         return self
```