Skip to content

Commit 83748bb

Browse files
authored
feat(torchtitan): enable dynamic model parameter override via CLI (#254)
1 parent de8f429 commit 83748bb

File tree

5 files changed

+140
-9
lines changed

5 files changed

+140
-9
lines changed

examples/megatron/prepare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def main():
191191

192192
log_info(f"BACKEND_PATH {args.backend_path}")
193193
# primus_config = PrimusParser().parse(args)
194-
primus_config = load_primus_config(args, unknown)
194+
primus_config, _ = load_primus_config(args, unknown)
195195

196196
primus_path = Path(args.primus_path).resolve()
197197
log_info(f"PRIMUS_PATH is set to: {primus_path}")

primus/core/launcher/parser.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from pathlib import Path
44
from types import SimpleNamespace
5-
from typing import List
5+
from typing import Any, Dict, List, Tuple
66

77
from primus.core.launcher.config import PrimusConfig
88
from primus.core.utils import constant_vars, yaml_utils
@@ -142,6 +142,30 @@ def _check_keys_exist(ns: SimpleNamespace, overrides: dict, prefix=""):
142142
_check_keys_exist(attr_val, v, prefix=full_key)
143143

144144

145+
def _split_known_unknown(ns: SimpleNamespace, overrides: dict) -> Tuple[dict, dict]:
146+
"""
147+
Split overrides into two dictionaries:
148+
- known: keys that exist in the namespace
149+
- unknown: keys not defined in the namespace
150+
"""
151+
known, unknown = {}, {}
152+
for k, v in overrides.items():
153+
if hasattr(ns, k):
154+
attr_val = getattr(ns, k)
155+
if isinstance(v, dict) and isinstance(attr_val, SimpleNamespace):
156+
sub_known, sub_unknown = _split_known_unknown(attr_val, v)
157+
if sub_known:
158+
known[k] = sub_known
159+
if sub_unknown:
160+
unknown[k] = sub_unknown
161+
else:
162+
known[k] = v
163+
else:
164+
unknown[k] = v
165+
# print(f"[PrimusConfig] Unknown key '{k}' delegated to backend.")
166+
return known, unknown
167+
168+
145169
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
146170
args, unknown_args = _parse_args(extra_args_provider, ignore_unknown_args=True)
147171

@@ -156,7 +180,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
156180
return primus_config
157181

158182

159-
def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> PrimusConfig:
183+
def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]:
160184
"""
161185
Build the Primus configuration with optional command-line overrides.
162186
@@ -177,10 +201,19 @@ def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Primus
177201

178202
# 3 Apply overrides to pre_trainer module config
179203
pre_trainer_cfg = primus_config.get_module_config("pre_trainer")
180-
_check_keys_exist(pre_trainer_cfg, override_ns)
181-
_deep_merge_namespace(pre_trainer_cfg, override_ns)
204+
# _check_keys_exist(pre_trainer_cfg, override_ns)
205+
# _deep_merge_namespace(pre_trainer_cfg, override_ns)
182206

183-
return primus_config
207+
# return primus_config
208+
known_overrides, unknown_overrides = _split_known_unknown(pre_trainer_cfg, override_ns)
209+
210+
if known_overrides:
211+
_deep_merge_namespace(pre_trainer_cfg, known_overrides)
212+
213+
if unknown_overrides:
214+
print(f"[PrimusConfig] Detected unknown override keys: {list(unknown_overrides.keys())}")
215+
216+
return primus_config, unknown_overrides
184217

185218

186219
class PrimusParser(object):

primus/modules/trainer/megatron/pre_trainer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,15 @@ def pop(cls):
118118
class MegatronPretrainTrainer(MegatronTrainer):
119119
def __init__(self, *args, **kwargs):
    """
    Initialize the trainer with ``module_name`` forced to ``"pre_trainer"``.

    The ``extra_args`` keyword is removed from ``kwargs`` before they are
    forwarded to the parent constructor.

    Raises:
        ValueError: If a non-empty ``extra_args`` is supplied.
    """
    kwargs["module_name"] = "pre_trainer"

    # Explicitly reject unknown extra_args; always pop so the key never
    # reaches the parent constructor.
    extra_args = kwargs.pop("extra_args", None)
    if extra_args:
        raise ValueError(
            f"[MegatronPretrainTrainer] Unexpected extra_args detected: {extra_args}. "
            "Megatron backend does not support unregistered config keys."
        )

    super().__init__(*args, **kwargs)
122131

123132
def get_batch(self, data_iterator):

primus/modules/trainer/torchtitan/pre_trainer.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
class TorchTitanPretrainTrainer(BaseModule):
1515
def __init__(self, *args, **kwargs):
16+
extra_args = kwargs.pop("extra_args", None)
1617
super().__init__(*args, **kwargs)
1718

1819
# important: make sure patch torchtitan logger first
@@ -26,6 +27,7 @@ def __init__(self, *args, **kwargs):
2627
cfg_dict = nested_namespace_to_dict(pre_trainer_cfg)
2728

2829
self.patch_torchtitan_embedding_amp(cfg_dict["primus_turbo"]["enable_embedding_autocast"])
30+
self.patch_titan_train_spec(pre_trainer_cfg.model.name, pre_trainer_cfg.model.flavor, extra_args)
2931

3032
# ensure checkpoint patch applied before import torchtitan
3133
# background: consolidate_safetensors_files_on_every_rank is a new DCP
@@ -508,3 +510,89 @@ def new_init(self, *args, **kwargs):
508510
primus_logger.info(
509511
"[PrimusPatch][AMP] nn.Embedding.__init__ patched for AMP/mixed precision alignment."
510512
)
513+
514+
def patch_titan_train_spec(self, model_name: str, flavor: str, model_overrides: Dict[str, Any]):
    """
    Monkey patch torchtitan.train_spec.get_train_spec to override model args dynamically.

    All override keys MUST start with "model." (e.g., {"model.n_layers": 8}).
    The one-level nested form {"model": {"n_layers": 8}} is accepted and
    flattened to the dotted form first.

    Args:
        model_name: Name passed to ``get_train_spec`` whose spec should be
            patched; specs requested under any other name are returned as-is.
        flavor: Key into ``spec.model_args`` selecting the dataclass to patch.
        model_overrides: Override mapping. When empty or ``None`` the
            function logs and returns without patching anything.

    Raises:
        ValueError: If any override key does not start with ``"model."``.

    The installed wrapper itself raises, when ``get_train_spec(model_name)``
    is called:
        AttributeError: If the spec has no ``model_args`` or an override
            names a field the model-args dataclass does not have.
        TypeError: If ``model_args`` is not a dict or the selected flavor is
            not a dataclass instance.
        KeyError: If ``flavor`` is not present in ``model_args``.
    """
    # Imported locally so this method does not depend on module-level imports.
    from dataclasses import asdict, is_dataclass

    from primus.core.utils.logger import _logger as primus_logger

    if not model_overrides:
        primus_logger.info("[PrimusPatch][ModelOverride] No model_overrides provided, skip patch.")
        return

    primus_logger.info(f"[PrimusPatch][ModelOverride] Applying model_overrides: {model_overrides}")

    # --- flatten nested form {"model": {"n_layers": 4}} → {"model.n_layers": 4}
    flat_overrides = {}
    for k, v in model_overrides.items():
        if k == "model" and isinstance(v, dict):
            for subk, subv in v.items():
                flat_overrides[f"model.{subk}"] = subv
        else:
            flat_overrides[k] = v
    model_overrides = flat_overrides

    # Enforce `model.` prefix strictly. Non-string keys are reported as bad
    # keys rather than crashing on `.startswith`.
    bad_keys = [k for k in model_overrides if not (isinstance(k, str) and k.startswith("model."))]
    if bad_keys:
        raise ValueError(
            f"[PrimusPatch][ModelOverride] Invalid override keys detected: {bad_keys}. "
            "These parameters belong to the model configuration and must be specified "
            "with the 'model.' prefix (e.g., 'model.n_layers', 'model.dim')."
        )

    primus_logger.info(
        f"[PrimusPatch][ModelOverride] model_overrides provided for '{model_name}' (flavor={flavor}): {model_overrides}"
    )

    import torchtitan.protocols.train_spec as train_spec_module

    orig_get_train_spec = train_spec_module.get_train_spec

    # Strip the "model." prefix once, outside the wrapper.
    field_overrides = {k[len("model.") :]: v for k, v in model_overrides.items()}

    def patched_get_train_spec(name: str):
        spec = orig_get_train_spec(name)
        if name != model_name:
            return spec  # only patch targeted model

        # Explicit raises instead of `assert`: assertions are removed when
        # Python runs with -O, which would silently skip these checks.
        if not hasattr(spec, "model_args"):
            raise AttributeError(f"[PrimusPatch][ModelOverride] train_spec for '{name}' missing model_args")
        model_args_root = spec.model_args
        if not isinstance(model_args_root, dict):
            raise TypeError(
                f"[PrimusPatch][ModelOverride] train_spec.model_args must be dict, got {type(model_args_root)}"
            )

        if flavor not in model_args_root:
            raise KeyError(
                f"[PrimusPatch][ModelOverride] flavor '{flavor}' not found in model_args for '{name}'. "
                f"Available flavors: {list(model_args_root.keys())}"
            )

        target_args = model_args_root[flavor]
        if not is_dataclass(target_args):
            raise TypeError(
                f"[PrimusPatch][ModelOverride] Expected dataclass model_args, got {type(target_args)}"
            )

        # Validate every field before mutating anything, so a bad key cannot
        # leave the dataclass partially patched.
        for field_name in field_overrides:
            if not hasattr(target_args, field_name):
                raise AttributeError(
                    f"[PrimusPatch][ModelOverride] '{type(target_args).__name__}' has no field '{field_name}'"
                )

        before = asdict(target_args)
        for field_name, value in field_overrides.items():
            setattr(target_args, field_name, value)

        primus_logger.info(
            f"[PrimusPatch][ModelOverride] Patched dataclass model_args['{flavor}'] "
            f"for '{name}' with {model_overrides} (before={before})"
        )
        return spec

    # Apply the patch globally
    train_spec_module.get_train_spec = patched_get_train_spec
    primus_logger.info(
        f"[PrimusPatch][ModelOverride] get_train_spec for '{model_name}' successfully monkey patched (flavor={flavor})."
    )

primus/pretrain.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def setup_env(data_path: str):
9898
print(f"[Primus CLI] HF_HOME already set: {hf_home}")
9999

100100

101-
def launch_pretrain_trainer(primus_cfg: PrimusConfig):
101+
def launch_pretrain_trainer(primus_cfg: PrimusConfig, extra_args=None):
102102
"""
103103
Launch the training using the Primus trainer.
104104
@@ -126,6 +126,7 @@ def launch_pretrain_trainer(primus_cfg: PrimusConfig):
126126
module_world_size=world_size,
127127
module_master_addr=master_addr,
128128
module_master_port=master_port,
129+
extra_args=extra_args,
129130
)
130131

131132
# Launch training
@@ -150,7 +151,7 @@ def launch_pretrain_from_cli(args, overrides):
150151

151152
setup_env(data_path=args.data_path)
152153

153-
primus_cfg = load_primus_config(args, overrides)
154+
primus_cfg, unknown_overrides = load_primus_config(args, overrides)
154155

155156
# Export merged config if requested
156157
if args.export_config:
@@ -160,7 +161,7 @@ def launch_pretrain_from_cli(args, overrides):
160161
framework = primus_cfg.get_module_config("pre_trainer").framework
161162
setup_backend_path(framework=framework, backend_path=args.backend_path, verbose=True)
162163

163-
launch_pretrain_trainer(primus_cfg=primus_cfg)
164+
launch_pretrain_trainer(primus_cfg=primus_cfg, extra_args=unknown_overrides)
164165

165166

166167
if __name__ == "__main__":

0 commit comments

Comments
 (0)