Skip to content

Commit 7593020

Browse files
m-bridge for nemo container 26.02
1 parent 7b63c79 commit 7593020

File tree

8 files changed

+202
-43
lines changed

8 files changed

+202
-43
lines changed

conf/experimental/megatron_bridge/test/b200/megatron_bridge_qwen_30b.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
2828
[cmd_args]
2929
gpu_type = "b200"
3030
container_image = "nvcr.io#nvidia/nemo:25.11.01"
31-
model_name = "qwen3"
32-
model_size = "30b_a3b"
31+
model_family_name = "qwen3"
32+
model_recipe_name = "30b_a3b"
3333
gpus_per_node = 4
3434
num_gpus = 8
3535
domain = "llm"

conf/experimental/megatron_bridge/test/gb200/megatron_bridge_qwen_30b.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
2828
[cmd_args]
2929
gpu_type = "gb200"
3030
container_image = "nvcr.io#nvidia/nemo:25.11.01"
31-
model_name = "qwen3"
32-
model_size = "30b_a3b"
31+
model_family_name = "qwen3"
32+
model_recipe_name = "30b_a3b"
3333
gpus_per_node = 4
3434
num_gpus = 8
3535
domain = "llm"

conf/experimental/megatron_bridge/test/gb300/megatron_bridge_qwen_30b.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
2828
[cmd_args]
2929
gpu_type = "gb300"
3030
container_image = "nvcr.io#nvidia/nemo:25.11.01"
31-
model_name = "qwen3"
32-
model_size = "30b_a3b"
31+
model_family_name = "qwen3"
32+
model_recipe_name = "30b_a3b"
3333
gpus_per_node = 4
3434
num_gpus = 8
3535
domain = "llm"

conf/experimental/megatron_bridge/test/h100/megatron_bridge_qwen_30b.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ mount_as = "/opt/Megatron-Bridge"
2828
[cmd_args]
2929
gpu_type = "h100"
3030
container_image = "nvcr.io#nvidia/nemo:25.11.01"
31-
model_name = "qwen3"
32-
model_size = "30b_a3b"
31+
model_family_name = "qwen3"
32+
model_recipe_name = "30b_a3b"
3333
gpus_per_node = 8
3434
num_gpus = 16
3535
domain = "llm"

doc/workloads/megatron_bridge.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ Test TOML example:
1919
# Container can be an NGC/enroot URL (nvcr.io#...) or a local .sqsh path.
2020
container_image = "nvcr.io#nvidia/nemo:25.11.01"
2121
22-
model_name = "qwen3"
23-
model_size = "30b_a3b"
22+
model_family_name = "qwen3"
23+
model_recipe_name = "30b_a3b"
2424
task = "pretrain"
2525
domain = "llm"
2626
compute_dtype = "fp8_mx"
@@ -55,8 +55,8 @@ Test-in-Scenario example:
5555
5656
[Tests.cmd_args]
5757
container_image = "nvcr.io#nvidia/nemo:25.11.01"
58-
model_name = "qwen3"
59-
model_size = "30b_a3b"
58+
model_family_name = "qwen3"
59+
model_recipe_name = "30b_a3b"
6060
task = "pretrain"
6161
domain = "llm"
6262
compute_dtype = "fp8_mx"

src/cloudai/workloads/megatron_bridge/megatron_bridge.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,23 @@ class MegatronBridgeCmdArgs(CmdArgs):
4040
detach: Optional[bool] = Field(default=None)
4141

4242
# Model/task
43-
model_name: str = Field(default="")
44-
model_size: str = Field(default="")
43+
model_family_name: str = Field(default="")
44+
model_recipe_name: str = Field(default="")
45+
use_recipes: Optional[bool] = Field(default=None)
4546
domain: str = Field(default="llm")
4647
task: str = Field(default="pretrain")
4748
compute_dtype: str = Field(default="bf16")
4849
fp8_recipe: Optional[str] = Field(default=None)
4950
hf_token: Optional[str] = Field(default=None)
5051
nemo_home: Optional[str] = Field(default=None)
5152
wandb_key: Optional[str] = Field(default=None)
52-
wandb_prj_name: Optional[str] = Field(default=None)
53-
wandb_exp_name: Optional[str] = Field(default=None)
53+
wandb_project_name: Optional[str] = Field(default=None)
54+
wandb_entity_name: Optional[str] = Field(default=None)
55+
wandb_experiment_name: Optional[str] = Field(default=None)
56+
wandb_save_dir: Optional[str] = Field(default=None)
57+
58+
# Retries
59+
max_retries: Optional[int] = Field(default=None)
5460

5561
# Feature flags (allow sweeps)
5662
use_tokendrop: Optional[Union[bool, List[bool]]] = Field(default=None)
@@ -69,6 +75,43 @@ class MegatronBridgeCmdArgs(CmdArgs):
6975
# Batch sizes
7076
mb: Optional[Union[int, List[int]]] = Field(default=None)
7177
gb: Optional[Union[int, List[int]]] = Field(default=None)
78+
seq_length: Optional[Union[int, List[int]]] = Field(default=None)
79+
80+
# Optimizer
81+
lr: Optional[Union[float, List[float]]] = Field(default=None)
82+
min_lr: Optional[Union[float, List[float]]] = Field(default=None)
83+
warmup_iters: Optional[Union[int, List[int]]] = Field(default=None)
84+
85+
# Checkpointing
86+
pretrained_checkpoint: Optional[str] = Field(default=None)
87+
save_dir: Optional[str] = Field(default=None)
88+
load_dir: Optional[str] = Field(default=None)
89+
save_interval: Optional[int] = Field(default=None)
90+
most_recent_k: Optional[int] = Field(default=None)
91+
save_config_filepath: Optional[str] = Field(default=None)
92+
93+
# Data / Tokenizer
94+
data: Optional[str] = Field(default=None)
95+
dataset_paths: Optional[Union[str, List[str]]] = Field(default=None)
96+
dataset_root: Optional[str] = Field(default=None)
97+
index_mapping_dir: Optional[str] = Field(default=None)
98+
dataset_name: Optional[str] = Field(default=None)
99+
packed_sequence: Optional[bool] = Field(default=None)
100+
head_only: Optional[bool] = Field(default=None)
101+
tokenizer_type: Optional[str] = Field(default=None)
102+
tokenizer_model: Optional[str] = Field(default=None)
103+
vocab_size: Optional[int] = Field(default=None)
104+
105+
# Profiling (performance group in argument_parser.py)
106+
pytorch_profiler: Optional[bool] = Field(default=None)
107+
profiling_start_step: Optional[int] = Field(default=None)
108+
profiling_stop_step: Optional[int] = Field(default=None)
109+
record_memory_history: Optional[bool] = Field(default=None)
110+
profiling_gpu_metrics: Optional[bool] = Field(default=None)
111+
profiling_ranks: Optional[Union[int, List[int]]] = Field(default=None)
112+
113+
# Performance
114+
nccl_ub: Optional[Union[bool, List[bool]]] = Field(default=None)
72115

73116
# Perf/tuning
74117
moe_a2a_overlap: Optional[Union[bool, List[bool]]] = Field(default=None)

src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ def _build_launcher_parts( # noqa: C901
154154
) -> list[str]:
155155
fields_set = args.model_fields_set
156156
force_fields = {
157-
"model_name",
158-
"model_size",
157+
"model_family_name",
158+
"model_recipe_name",
159159
"num_gpus",
160160
"gpus_per_node",
161161
"hf_token",
@@ -198,6 +198,15 @@ def add(flag: str, value: Any) -> None:
198198
return
199199
if isinstance(value, bool):
200200
parts.extend([flag, "true" if value else "false"])
201+
elif isinstance(value, (list, tuple)):
202+
if not value:
203+
return
204+
if flag == "--dataset_paths":
205+
parts.extend([flag, *[str(x) for x in value]])
206+
elif flag == "--profiling_ranks":
207+
parts.extend([flag, ",".join(str(x) for x in value)])
208+
else:
209+
parts.extend([flag, str(value[0]) if len(value) == 1 else ",".join(str(x) for x in value)])
201210
else:
202211
sv = str(value)
203212
if sv != "":
@@ -222,31 +231,38 @@ def add_field(field: str, flag: str, value: Any) -> None:
222231
add_field("hf_token", "-hf", args.hf_token)
223232
add_field("nemo_home", "-nh", args.nemo_home)
224233
add_field("wandb_key", "-wdk", args.wandb_key)
225-
add_field("wandb_prj_name", "-wdp", args.wandb_prj_name)
226-
add_field("wandb_exp_name", "-wdj", args.wandb_exp_name)
234+
add_field("wandb_project_name", "-wdp", args.wandb_project_name)
235+
add_field("wandb_entity_name", "-wde", args.wandb_entity_name)
236+
add_field("wandb_experiment_name", "-wdj", args.wandb_experiment_name)
237+
add_field("wandb_save_dir", "-wds", args.wandb_save_dir)
238+
add_field("max_retries", "--max_retries", args.max_retries)
227239
if args.dryrun and "dryrun" in fields_set:
228240
parts.append("-d")
229241
add_field("num_gpus", "-ng", args.num_gpus)
230242
add_field("gpus_per_node", "-gn", args.gpus_per_node)
231243
if mounts:
232244
add("-cm", ",".join(mounts))
233245

234-
# Model flags (Megatron-Bridge r0.2.0 API)
246+
# Model flags (Megatron-Bridge main-branch API)
247+
if args.use_recipes and "use_recipes" in fields_set:
248+
parts.append("--use_recipes")
235249
if "enable_vboost" in fields_set:
236250
add_field("enable_vboost", "-vb", bool(args.enable_vboost))
237-
if not args.model_name:
238-
raise RuntimeError("Missing required cmd_args.model_name (maps to -m/--model_name).")
239-
if not args.model_size:
240-
raise RuntimeError("Missing required cmd_args.model_size (maps to -s/--model_size).")
241-
add_field("model_name", "-m", args.model_name)
242-
add_field("model_size", "-s", args.model_size)
251+
if not args.model_family_name:
252+
raise RuntimeError("Missing required cmd_args.model_family_name (maps to -m/--model_family_name).")
253+
if not args.model_recipe_name:
254+
raise RuntimeError("Missing required cmd_args.model_recipe_name (maps to -mr/--model_recipe_name).")
255+
add_field("model_family_name", "-m", args.model_family_name)
256+
add_field("model_recipe_name", "-mr", args.model_recipe_name)
243257
if args.enable_nsys and "enable_nsys" in fields_set:
244258
parts.append("-en")
245259
add_field("domain", "--domain", args.domain)
246260
if "use_tokendrop" in fields_set and args.use_tokendrop is not None:
247261
add_field("use_tokendrop", "--use_tokendrop", bool(args.use_tokendrop))
248262
if "use_megatron_fsdp" in fields_set and args.use_megatron_fsdp is not None:
249263
add_field("use_megatron_fsdp", "--use_megatron_fsdp", bool(args.use_megatron_fsdp))
264+
if "nccl_ub" in fields_set and args.nccl_ub is not None:
265+
add_field("nccl_ub", "--nccl_ub", bool(args.nccl_ub))
250266
add_field("cuda_graph_impl", "--cuda_graph_impl", args.cuda_graph_impl)
251267
if args.cuda_graph_scope and "cuda_graph_scope" in fields_set:
252268
add_field(
@@ -264,6 +280,7 @@ def add_field(field: str, flag: str, value: Any) -> None:
264280
# Batch
265281
add_field("mb", "-mb", args.mb)
266282
add_field("gb", "-gb", args.gb)
283+
add_field("seq_length", "-sl", args.seq_length)
267284

268285
# Misc
269286
if "moe_a2a_overlap" in fields_set:
@@ -273,11 +290,44 @@ def add_field(field: str, flag: str, value: Any) -> None:
273290
add_field("activation_offload_layers", "-ol", args.activation_offload_layers)
274291
if args.recompute_modules and "recompute_modules" in fields_set:
275292
parts.extend(["--recompute_modules", self._normalize_recompute_modules(args.recompute_modules)])
276-
# r0.2.0 supports `--detach` / `--no-detach` flags (no boolean value)
277-
if args.detach is True and "detach" in fields_set:
278-
parts.append("--detach")
279-
elif args.detach is False and "detach" in fields_set:
280-
parts.append("--no-detach")
293+
if "detach" in fields_set and args.detach is not None:
294+
parts.extend(["--detach", "true" if args.detach else "false"])
295+
296+
# Optimizer
297+
add_field("lr", "--lr", args.lr)
298+
add_field("min_lr", "--min_lr", args.min_lr)
299+
add_field("warmup_iters", "--warmup_iters", args.warmup_iters)
300+
301+
# Checkpointing
302+
add_field("pretrained_checkpoint", "--pretrained_checkpoint", args.pretrained_checkpoint)
303+
add_field("save_dir", "--save_dir", args.save_dir)
304+
add_field("load_dir", "--load_dir", args.load_dir)
305+
add_field("save_interval", "--save_interval", args.save_interval)
306+
add_field("most_recent_k", "--most_recent_k", args.most_recent_k)
307+
add_field("save_config_filepath", "--save_config_filepath", args.save_config_filepath)
308+
309+
# Data / Tokenizer
310+
add_field("data", "--data", args.data)
311+
add_field("dataset_paths", "--dataset_paths", args.dataset_paths)
312+
add_field("dataset_root", "--dataset_root", args.dataset_root)
313+
add_field("index_mapping_dir", "--index_mapping_dir", args.index_mapping_dir)
314+
add_field("dataset_name", "--dataset_name", args.dataset_name)
315+
if args.packed_sequence and "packed_sequence" in fields_set:
316+
parts.append("--packed_sequence")
317+
if args.head_only and "head_only" in fields_set:
318+
parts.append("--head_only")
319+
add_field("tokenizer_type", "--tokenizer_type", args.tokenizer_type)
320+
add_field("tokenizer_model", "--tokenizer_model", args.tokenizer_model)
321+
add_field("vocab_size", "--vocab_size", args.vocab_size)
322+
323+
# Profiling (performance group)
324+
add_field("pytorch_profiler", "-pyp", args.pytorch_profiler)
325+
add_field("profiling_start_step", "--profiling_start_step", args.profiling_start_step)
326+
add_field("profiling_stop_step", "--profiling_stop_step", args.profiling_stop_step)
327+
add_field("record_memory_history", "-mh", args.record_memory_history)
328+
if args.profiling_gpu_metrics and "profiling_gpu_metrics" in fields_set:
329+
parts.append("--profiling_gpu_metrics")
330+
add_field("profiling_ranks", "--profiling_ranks", args.profiling_ranks)
281331

282332
# Extra user args (dict -> string)
283333
if tdef.extra_cmd_args:

0 commit comments

Comments
 (0)