1313
1414class TorchTitanPretrainTrainer (BaseModule ):
1515 def __init__ (self , * args , ** kwargs ):
16+ extra_args = kwargs .pop ("extra_args" , None )
1617 super ().__init__ (* args , ** kwargs )
1718
1819 # important: make sure patch torchtitan logger first
@@ -26,6 +27,7 @@ def __init__(self, *args, **kwargs):
2627 cfg_dict = nested_namespace_to_dict (pre_trainer_cfg )
2728
2829 self .patch_torchtitan_embedding_amp (cfg_dict ["primus_turbo" ]["enable_embedding_autocast" ])
30+ self .patch_titan_train_spec (pre_trainer_cfg .model .name , pre_trainer_cfg .model .flavor , extra_args )
2931
3032 # ensure checkpoint patch applied before import torchtitan
3133 # background: consolidate_safetensors_files_on_every_rank is a new DCP
@@ -508,3 +510,89 @@ def new_init(self, *args, **kwargs):
508510 primus_logger .info (
509511 "[PrimusPatch][AMP] nn.Embedding.__init__ patched for AMP/mixed precision alignment."
510512 )
513+
514+ def patch_titan_train_spec (self , model_name : str , flavor : str , model_overrides : Dict [str , Any ]):
515+ """
516+ Monkey patch torchtitan.train_spec.get_train_spec to override model args dynamically.
517+ All override keys MUST start with "model." (e.g., {"model.n_layers": 8}).
518+ """
519+ from primus .core .utils .logger import _logger as primus_logger
520+
521+ if not model_overrides :
522+ primus_logger .info ("[PrimusPatch][ModelOverride] No model_overrides provided, skip patch." )
523+ return
524+
525+ primus_logger .info (f"[PrimusPatch][ModelOverride] Applying model_overrides: { model_overrides } " )
526+
527+ # --- flatten nested form {"model": {"n_layers": 4}} → {"model.n_layers": 4}
528+ flat_overrides = {}
529+ for k , v in model_overrides .items ():
530+ if k == "model" and isinstance (v , dict ):
531+ for subk , subv in v .items ():
532+ flat_overrides [f"model.{ subk } " ] = subv
533+ else :
534+ flat_overrides [k ] = v
535+ model_overrides = flat_overrides
536+
537+ # Enforce `model.` prefix strictly
538+ bad_keys = [k for k in model_overrides .keys () if not k .startswith ("model." )]
539+ if bad_keys :
540+ raise ValueError (
541+ # f"[PrimusPatch][ModelOverride] Unsupported override keys (must start with 'model.'): {bad_keys}"
542+ f"[PrimusPatch][ModelOverride] Invalid override keys detected: { bad_keys } . "
543+ "These parameters belong to the model configuration and must be specified "
544+ "with the 'model.' prefix (e.g., 'model.n_layers', 'model.dim')."
545+ )
546+
547+ primus_logger .info (
548+ f"[PrimusPatch][ModelOverride] model_overrides provided for '{ model_name } ' (flavor={ flavor } ): { model_overrides } "
549+ )
550+
551+ import torchtitan .protocols .train_spec as train_spec_module
552+
553+ orig_get_train_spec = train_spec_module .get_train_spec
554+
555+ def patched_get_train_spec (name : str ):
556+ spec = orig_get_train_spec (name )
557+ if name != model_name :
558+ return spec # only patch targeted model
559+
560+ assert hasattr (
561+ spec , "model_args"
562+ ), f"[PrimusPatch][ModelOverride] train_spec for '{ name } ' missing model_args"
563+ model_args_root = spec .model_args
564+ assert isinstance (
565+ model_args_root , dict
566+ ), f"[PrimusPatch][ModelOverride] train_spec.model_args must be dict, got { type (model_args_root )} "
567+
568+ if flavor not in model_args_root :
569+ raise KeyError (
570+ f"[PrimusPatch][ModelOverride] flavor '{ flavor } ' not found in model_args for '{ name } '. "
571+ f"Available flavors: { list (model_args_root .keys ())} "
572+ )
573+
574+ target_args = model_args_root [flavor ]
575+ assert is_dataclass (
576+ target_args
577+ ), f"[PrimusPatch][ModelOverride] Expected dataclass model_args, got { type (target_args )} "
578+
579+ before = asdict (target_args )
580+ for k , v in model_overrides .items ():
581+ field_name = k [len ("model." ) :]
582+ if not hasattr (target_args , field_name ):
583+ raise AttributeError (
584+ f"[PrimusPatch][ModelOverride] '{ type (target_args ).__name__ } ' has no field '{ field_name } '"
585+ )
586+ setattr (target_args , field_name , v )
587+
588+ primus_logger .info (
589+ f"[PrimusPatch][ModelOverride] Patched dataclass model_args['{ flavor } '] "
590+ f"for '{ name } ' with { model_overrides } (before={ before } )"
591+ )
592+ return spec
593+
594+ # Apply the patch globally
595+ train_spec_module .get_train_spec = patched_get_train_spec
596+ primus_logger .info (
597+ f"[PrimusPatch][ModelOverride] get_train_spec for '{ model_name } ' successfully monkey patched (flavor={ flavor } )."
598+ )
0 commit comments