Commit c59fd89

fix pipeline layout + tokenizer config
1 parent b0423ff commit c59fd89

1 file changed: +17 −21 lines

moe_pretraining/nemo/pretrain_deepseek_v3_671b.py

Lines changed: 17 additions & 21 deletions
@@ -15,10 +15,11 @@
 import argparse
 import os
 
-from megatron.bridge.recipes.deepseek import deepseek_v3_pretrain_config
+from megatron.bridge.recipes.deepseek import deepseek_v3_pretrain_config, set_deepseek_v3_pipeline_model_parallel_layout
 from megatron.bridge.training.config import GPTDatasetConfig, ConfigContainer
 from megatron.bridge.training.gpt_step import forward_step
 from megatron.bridge.training.pretrain import pretrain
+from megatron.bridge.training.tokenizers.config import TokenizerConfig
 
 from callback import (
     MLPerfLoggingCallback,
@@ -113,31 +114,23 @@ def log_hyperparams(args, mbridge_config: ConfigContainer):
         mllogger.event(key=key, value=value)
 
 
+def get_tokenizer_config():
+    return TokenizerConfig(
+        tokenizer_type="HuggingFaceTokenizer",
+        tokenizer_model="/tokenizer",
+        hf_tokenizer_kwargs={"use_fast": True},
+    )
+
+
 def create_config(args):
     """Create the training configuration from arguments."""
-    config = deepseek_v3_pretrain_config(
-        pipeline_model_parallel_size=args.pipeline_parallel_size,
-        virtual_pipeline_parallel_size=args.virtual_pipeline_parallel_size,
-    )
+    config = deepseek_v3_pretrain_config()
 
     # Model parallelism configuration (hardcoded for DeepSeek V3)
     model_cfg = config.model
-    model_cfg.tensor_model_parallel_size = args.tensor_parallel_size
-    model_cfg.context_parallel_size = args.context_parallel_size
-    model_cfg.expert_model_parallel_size = args.expert_model_parallel_size
-    model_cfg.expert_tensor_parallel_size = args.expert_tensor_parallel_size
-    model_cfg.sequence_parallel = args.tensor_parallel_size > 1
-    model_cfg.seq_length = args.sequence_length
-    model_cfg.recompute_modules = args.recompute_modules.split(",") if args.recompute_modules else []
-    model_cfg.cuda_graph_implementation = args.cuda_graph_implementation
-    model_cfg.cuda_graph_scope = args.cuda_graph_scope.split(",") if args.cuda_graph_scope else []
-
-    # MoE parameters (hardcoded for DeepSeek V3)
-    model_cfg.moe_token_dispatcher_type = args.moe_token_dispatcher_type
-    model_cfg.moe_grouped_gemm = args.moe_grouped_gemm
-    model_cfg.moe_permute_fusion = args.moe_permute_fusion
-    model_cfg.moe_router_fusion = args.moe_router_fusion
-    model_cfg.moe_router_force_load_balancing = False
+    model_cfg.pipeline_model_parallel_size = args.pipeline_parallel_size
+    model_cfg.virtual_pipeline_model_parallel_size = args.virtual_pipeline_parallel_size
+    set_deepseek_v3_pipeline_model_parallel_layout(model_cfg)
 
     # Training configuration
     train_cfg = config.train
@@ -168,6 +161,9 @@ def create_config(args):
         seed=args.seed,
     )
 
+    # Tokenizer configuration
+    config.tokenizer = get_tokenizer_config()
+
     # Checkpoint configuration
     checkpoint_cfg = config.checkpoint
     checkpoint_cfg.load = "/checkpoint"
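
Taken together, the three hunks change how create_config wires up pipeline parallelism and the tokenizer. The sketch below reconstructs the post-commit flow from the visible hunks only; everything else in the file (argument parsing, the remaining training, dataset, and checkpoint setup) is elided, and the trailing return is an assumption, since it sits outside these hunks.

# Minimal sketch of the post-commit flow, reconstructed from the hunks above.
# Anything not shown in the diff (argument parsing, training/dataset/checkpoint
# setup, the final return) is elided or assumed.
from megatron.bridge.recipes.deepseek import (
    deepseek_v3_pretrain_config,
    set_deepseek_v3_pipeline_model_parallel_layout,
)
from megatron.bridge.training.tokenizers.config import TokenizerConfig


def get_tokenizer_config():
    # HuggingFace tokenizer loaded from the /tokenizer mount, fast implementation enabled.
    return TokenizerConfig(
        tokenizer_type="HuggingFaceTokenizer",
        tokenizer_model="/tokenizer",
        hf_tokenizer_kwargs={"use_fast": True},
    )


def create_config(args):
    # Pipeline sizes are no longer passed to the recipe function; they are set
    # on the model config and the DeepSeek V3 layout helper is applied to it.
    config = deepseek_v3_pretrain_config()
    model_cfg = config.model
    model_cfg.pipeline_model_parallel_size = args.pipeline_parallel_size
    model_cfg.virtual_pipeline_model_parallel_size = args.virtual_pipeline_parallel_size
    set_deepseek_v3_pipeline_model_parallel_layout(model_cfg)

    # The tokenizer is now configured explicitly (new in this commit).
    config.tokenizer = get_tokenizer_config()
    return config  # assumed; the return is not visible in the hunks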
