@@ -18,6 +18,8 @@ def __init__(self, *args, **kwargs):
         # important: make sure patch torchtitan logger first
         self.patch_torchtitan_logger()

+        self.patch_torchtitan_embedding_amp()
+
         # ensure checkpoint patch applied before import torchtitan
         # background: consolidate_safetensors_files_on_every_rank is a new DCP
         # utility introduced in newer torch versions. our current build does not
@@ -465,3 +467,46 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any:
             else:
                 init_values[f.name] = val
         return cls(**init_values)
+
+    def patch_torchtitan_embedding_amp(self):
+        """
+        Monkey patch for AMP precision mismatch in nn.Embedding.
+
+        Behavior:
+            Globally patches nn.Embedding.__init__ to register a forward hook that:
+              - When AMP/autocast is active, casts outputs to the AMP dtype (bf16/fp16).
+              - Otherwise, uses mixed_precision_param from the Titan config.
+              - Can be disabled via env: export PRIMUS_EMBED_AUTOCAST_DTYPE=off
+        """
+        import os
+
+        import torch
+        import torch.nn as nn
+
+        from primus.core.utils.logger import _logger as primus_logger
+
+        env_flag = os.getenv("PRIMUS_EMBED_AUTOCAST_DTYPE", "auto").lower()
+        if env_flag in ("off", "false", "none"):
+            primus_logger.info("[PrimusPatch][AMP] Embedding AMP patch disabled via env.")
+            return
+
+        def _hook(module, inp, out):
+            # leave non-tensor / non-floating-point outputs untouched
+            if not isinstance(out, torch.Tensor) or not out.is_floating_point():
+                return out
+
+            # under autocast, align the embedding output with the autocast dtype
+            if torch.is_autocast_enabled():
+                runtime_dtype = torch.get_autocast_gpu_dtype()
+                if out.dtype != runtime_dtype:
+                    return out.to(runtime_dtype)
+            return out
+
+        orig_init = nn.Embedding.__init__
+
+        def new_init(self, *args, **kwargs):
+            orig_init(self, *args, **kwargs)
+            self.register_forward_hook(_hook)
+
+        nn.Embedding.__init__ = new_init
+        primus_logger.info(
+            "[PrimusPatch][AMP] nn.Embedding.__init__ patched for AMP/mixed precision alignment."
+        )
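
For reference, here is a minimal standalone sketch (not part of the patch) of what the registered forward hook does: an fp32 `nn.Embedding` run inside a `torch.autocast` region keeps its fp32 output by default, and the hook casts it to the autocast dtype. The `embedding_autocast_hook` name, the toy sizes, and the CUDA device are assumptions made purely for illustration.

```python
import torch
import torch.nn as nn


def embedding_autocast_hook(module, inp, out):
    # hypothetical standalone copy of the hook: cast floating-point
    # embedding outputs to the active autocast (AMP) dtype
    if not isinstance(out, torch.Tensor) or not out.is_floating_point():
        return out
    if torch.is_autocast_enabled():
        runtime_dtype = torch.get_autocast_gpu_dtype()
        if out.dtype != runtime_dtype:
            return out.to(runtime_dtype)
    return out


# assumes a CUDA device; sizes are arbitrary toy values
embed = nn.Embedding(1000, 64).cuda()  # fp32 weights
embed.register_forward_hook(embedding_autocast_hook)

tokens = torch.randint(0, 1000, (2, 8), device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = embed(tokens)

# without the hook this prints torch.float32, because nn.Embedding is not
# on the autocast cast list; with the hook it prints torch.bfloat16
print(out.dtype)
```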