Commit 39f7de0

[Fix][Temp] Avoid recompiling on globals and closure variables
1 parent fbfd5f5 commit 39f7de0

5 files changed: 21 additions & 12 deletions

xtuner/v1/module/decoder_layer/dense_decoder_layer.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def __init__(
             generate_config=generate_config,
             float8_cfg=float8_cfg,
         )
+        self.self_attn.name = f"layers.{layer_idx}.self_attn"
         self.mlp = DenseMLP(
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,

xtuner/v1/module/decoder_layer/moe_decoder_layer.py

Lines changed: 1 addition & 0 deletions
@@ -214,6 +214,7 @@ def __init__(
             layer_type=layer_type,
             float8_cfg=float8_cfg,
         )
+        self.self_attn.name = f"layers.{layer_idx}.self_attn"
         self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
         self.shared_experts: MoEMLP | None
         self.layer_idx = layer_idx
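Note: both decoder layers now stamp a stable, fully qualified name onto their attention submodule at construction time, so downstream forward hooks can identify the layer by reading module.name instead of capturing a per-layer string in a closure (see the internal_metrics.py hunk below). A minimal sketch of the tagging pattern, using hypothetical Attn and Layer classes:

from torch import nn

class Attn(nn.Module):  # stand-in for the real attention module
    def forward(self, x):
        return x

class Layer(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.self_attn = Attn()
        # Tag the submodule once at construction; nn.Module stores
        # plain-string attributes alongside parameters and buffers.
        self.self_attn.name = f"layers.{layer_idx}.self_attn"

layers = [Layer(i) for i in range(2)]
print([layer.self_attn.name for layer in layers])
# ['layers.0.self_attn', 'layers.1.self_attn']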

xtuner/v1/train/trainer.py

Lines changed: 4 additions & 1 deletion
@@ -620,7 +620,10 @@ def _maybe_check_model_internal_metrics(self, data_batches: list[ModelItem]) ->
             return None

         with InternalMetricsRecorder(self._engine) as metrics_recorder:
-            return metrics_recorder.get_metrics(data_batches)
+            logger.info("Start calculating model internal metrics...")
+            metrics: InternalMetrics = metrics_recorder.get_metrics(data_batches)
+            logger.info("Calculating model internal metrics done.")
+            return metrics

     @property
     def world_size(self) -> int:

xtuner/v1/utils/compile.py

Lines changed: 5 additions & 1 deletion
@@ -70,7 +70,11 @@ def decorator(func):
         func_compile_kwargs = {**compile_kwargs, **target_kwargs}

         # Compile the function
-        self._compiled_funcs[func_id] = torch.compile(original_func, **func_compile_kwargs)
+        self._compiled_funcs[func_id] = torch.compile(
+            original_func,
+            options={"guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe},
+            **func_compile_kwargs,
+        )

         @functools.wraps(func)
         def wrapper(*args, **kwargs):
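Note: this is the core of the fix. The guard_filter_fn compile option (available in recent PyTorch releases) lets the caller drop Dynamo guards, and the stock filter torch.compiler.skip_guard_on_globals_unsafe removes every guard on global variables, so mutating a global no longer triggers a recompile. The "unsafe" suffix is literal: if a global read by the compiled code actually changes, the cached graph silently keeps the old value. A minimal standalone sketch of the same idea (FLAG and scale are illustrative names, not from this repo):

import torch

FLAG = 2.0  # module-level global read inside the compiled function

def scale(x: torch.Tensor) -> torch.Tensor:
    return x * FLAG

compiled = torch.compile(
    scale,
    options={"guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe},
)

print(compiled(torch.ones(2)))  # first call compiles; guards on FLAG are dropped
FLAG = 3.0
print(compiled(torch.ones(2)))  # no recompile -- the stale 2.0 may be reused

Applying the filter to every function this utility compiles matches the commit title: guards on globals (and, per the hook change below, on closure variables) were forcing needless recompiles during training.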

xtuner/v1/utils/internal_metrics.py

Lines changed: 10 additions & 10 deletions
@@ -70,24 +70,24 @@ def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype:
         param_rms = param_l2_norm / total_numel**0.5
         self.metrics["weight_rms"][layer_name] = param_rms.item()

-    def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
+    def register_attn_extra_info_hook(self, module: nn.Module):
         """Register attention extra info hook as a forward hook"""
         def hook(module, input, output):
             extra_info = output[1]
             if extra_info.get("softmax_lse", None) is not None:
-                if layer_name not in ATTN_MAX_LSE:
+                if module.name not in ATTN_MAX_LSE:
                     # original shape: [n_head, seq]
-                    ATTN_MAX_LSE[layer_name] = extra_info["softmax_lse"].max()
+                    ATTN_MAX_LSE[module.name] = extra_info["softmax_lse"].max()
                 else:
-                    prev_lse_max = ATTN_MAX_LSE[layer_name]
-                    ATTN_MAX_LSE[layer_name] = max(prev_lse_max, extra_info["softmax_lse"].max())
+                    prev_lse_max = ATTN_MAX_LSE[module.name]
+                    ATTN_MAX_LSE[module.name] = max(prev_lse_max, extra_info["softmax_lse"].max())
             if extra_info.get("attn_logits", None) is not None:
-                if layer_name not in ATTN_MAX_LOGITS:
+                if module.name not in ATTN_MAX_LOGITS:
                     # original shape: [b, n_head, seq, seq]
-                    ATTN_MAX_LOGITS[layer_name] = extra_info["attn_logits"].max()
+                    ATTN_MAX_LOGITS[module.name] = extra_info["attn_logits"].max()
                 else:
-                    prev_logits_max = ATTN_MAX_LOGITS[layer_name]
-                    ATTN_MAX_LOGITS[layer_name] = max(prev_logits_max, extra_info["attn_logits"].max())
+                    prev_logits_max = ATTN_MAX_LOGITS[module.name]
+                    ATTN_MAX_LOGITS[module.name] = max(prev_logits_max, extra_info["attn_logits"].max())

         hook_handle: RemovableHandle = module.register_forward_hook(hook)
         self.hooks.append(hook_handle)

@@ -187,7 +187,7 @@ def get_metrics(self, data_batches: list[ModelItem]):
     def __enter__(self):
         for name, module in self.model.named_modules():
            if isinstance(module, ATTENTION_CLS):
-                self.register_attn_extra_info_hook(module, self._clean_module_name(name))
+                self.register_attn_extra_info_hook(module)
             if isinstance(module, RMS_NORM_MONITOR_MODULES):
                 self.calculate_module_weight_rms(module, self._clean_module_name(name), dtype=torch.float32)
         return self
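Note: this change removes the hook's only closure variable. Previously each registered hook closed over its own layer_name cell, giving Dynamo a per-layer closure variable to guard on; now the hook reads the name from the attribute stamped onto the module by the decoder layers, so every layer shares one function body with no captured state. A minimal sketch of the closure-free pattern (the Attn stand-in and its softmax_lse payload are simplified here, not the real module):

import torch
from torch import nn

ATTN_MAX_LSE: dict[str, torch.Tensor] = {}

class Attn(nn.Module):  # stand-in for the real attention module
    def forward(self, x):
        # Real attention modules return (output, extra_info); mimic that shape.
        return x, {"softmax_lse": x.logsumexp(dim=-1)}

def register_attn_extra_info_hook(module: nn.Module):
    # The hook closes over nothing: the layer's identity arrives via
    # module.name, so one guard set covers every layer.
    def hook(module, input, output):
        lse_max = output[1]["softmax_lse"].max()
        prev = ATTN_MAX_LSE.get(module.name)
        ATTN_MAX_LSE[module.name] = lse_max if prev is None else max(prev, lse_max)
    return module.register_forward_hook(hook)

attn = Attn()
attn.name = "layers.0.self_attn"  # normally set by the owning decoder layer
handle = register_attn_extra_info_hook(attn)
attn(torch.randn(2, 8))
print(ATTN_MAX_LSE)  # {'layers.0.self_attn': tensor(...)}
handle.remove()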
