Skip to content

Commit a151b20

Browse files
committed
fix(amp): patch nn.Embedding for AMP autocast alignment (bf16/fp16)
1 parent 48459a2 commit a151b20

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

primus/modules/trainer/torchtitan/pre_trainer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def __init__(self, *args, **kwargs):
1818
# important: make sure to patch the torchtitan logger first
1919
self.patch_torchtitan_logger()
2020

21+
self.patch_torchtitan_embedding_amp()
22+
2123
# ensure checkpoint patch applied before import torchtitan
2224
# background: consolidate_safetensors_files_on_every_rank is a new DCP
2325
# utility introduced in newer torch versions. our current build does not
@@ -465,3 +467,46 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any:
465467
else:
466468
init_values[f.name] = val
467469
return cls(**init_values)
470+
471+
def patch_torchtitan_embedding_amp(self):
    """
    Monkey patch nn.Embedding so its output dtype follows AMP autocast.

    Behavior:
        Globally replaces nn.Embedding.__init__ with a wrapper that registers
        a forward hook on every embedding constructed afterwards. The hook:
        - When torch.is_autocast_enabled() is true, casts a floating-point
          tensor output to the dtype returned by
          torch.get_autocast_gpu_dtype() (e.g. bf16/fp16).
        - Otherwise (autocast inactive, or the output is not a floating-point
          tensor) returns the output unchanged.

        The patch is idempotent: calling this method again after the wrapper
        has been installed is a no-op, so embeddings never accumulate
        duplicate hooks.

        Embeddings constructed before the patch is installed are not affected.

        Can be disabled via env: export PRIMUS_EMBED_AUTOCAST_DTYPE=off
        ("false" and "none" are accepted too, case-insensitively).
    """
    import functools
    import os

    import torch
    import torch.nn as nn

    from primus.core.utils.logger import _logger as primus_logger

    env_flag = os.getenv("PRIMUS_EMBED_AUTOCAST_DTYPE", "auto").lower()
    if env_flag in ("off", "false", "none"):
        primus_logger.info("[PrimusPatch][AMP] Embedding AMP patch disabled via env.")
        return

    orig_init = nn.Embedding.__init__

    # Guard against re-patching: wrapping the already-wrapped __init__ would
    # register one extra forward hook per call on every new embedding.
    if getattr(orig_init, "_primus_amp_patched", False):
        return

    def _hook(module, inp, out):
        # Only floating-point tensors are candidates for a dtype cast.
        if not isinstance(out, torch.Tensor) or not out.is_floating_point():
            return out

        if torch.is_autocast_enabled():
            runtime_dtype = torch.get_autocast_gpu_dtype()
            if out.dtype != runtime_dtype:
                return out.to(runtime_dtype)
        return out

    @functools.wraps(orig_init)
    def new_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        self.register_forward_hook(_hook)

    # Marker read by the idempotency guard above on subsequent calls.
    new_init._primus_amp_patched = True

    nn.Embedding.__init__ = new_init
    primus_logger.info(
        "[PrimusPatch][AMP] nn.Embedding.__init__ patched for AMP/mixed precision alignment."
    )

0 commit comments

Comments
 (0)