15 | 15 | import argparse
16 | 16 | import os
17 | 17 |
18 | | -from megatron.bridge.recipes.deepseek import deepseek_v3_pretrain_config
| 18 | +from megatron.bridge.recipes.deepseek import deepseek_v3_pretrain_config, set_deepseek_v3_pipeline_model_parallel_layout
19 | 19 | from megatron.bridge.training.config import GPTDatasetConfig, ConfigContainer
20 | 20 | from megatron.bridge.training.gpt_step import forward_step
21 | 21 | from megatron.bridge.training.pretrain import pretrain
| 22 | +from megatron.bridge.training.tokenizers.config import TokenizerConfig
22 | 23 |
23 | 24 | from callback import (
24 | 25 |     MLPerfLoggingCallback,
@@ -113,31 +114,23 @@ def log_hyperparams(args, mbridge_config: ConfigContainer):
113 | 114 |         mllogger.event(key=key, value=value)
114 | 115 |
115 | 116 |
| 117 | +def get_tokenizer_config():
| 118 | +    return TokenizerConfig(
| 119 | +        tokenizer_type="HuggingFaceTokenizer",
| 120 | +        tokenizer_model="/tokenizer",
| 121 | +        hf_tokenizer_kwargs={"use_fast": True},
| 122 | +    )
| 123 | +
| 124 | +
116 | 125 | def create_config(args):
117 | 126 |     """Create the training configuration from arguments."""
118 | | -    config = deepseek_v3_pretrain_config(
119 | | -        pipeline_model_parallel_size=args.pipeline_parallel_size,
120 | | -        virtual_pipeline_parallel_size=args.virtual_pipeline_parallel_size,
121 | | -    )
| 127 | +    config = deepseek_v3_pretrain_config()
122 | 128 |
123 | 129 |     # Model parallelism configuration (hardcoded for DeepSeek V3)
124 | 130 |     model_cfg = config.model
125 | | -    model_cfg.tensor_model_parallel_size = args.tensor_parallel_size
126 | | -    model_cfg.context_parallel_size = args.context_parallel_size
127 | | -    model_cfg.expert_model_parallel_size = args.expert_model_parallel_size
128 | | -    model_cfg.expert_tensor_parallel_size = args.expert_tensor_parallel_size
129 | | -    model_cfg.sequence_parallel = args.tensor_parallel_size > 1
130 | | -    model_cfg.seq_length = args.sequence_length
131 | | -    model_cfg.recompute_modules = args.recompute_modules.split(",") if args.recompute_modules else []
132 | | -    model_cfg.cuda_graph_implementation = args.cuda_graph_implementation
133 | | -    model_cfg.cuda_graph_scope = args.cuda_graph_scope.split(",") if args.cuda_graph_scope else []
134 | | -
135 | | -    # MoE parameters (hardcoded for DeepSeek V3)
136 | | -    model_cfg.moe_token_dispatcher_type = args.moe_token_dispatcher_type
137 | | -    model_cfg.moe_grouped_gemm = args.moe_grouped_gemm
138 | | -    model_cfg.moe_permute_fusion = args.moe_permute_fusion
139 | | -    model_cfg.moe_router_fusion = args.moe_router_fusion
140 | | -    model_cfg.moe_router_force_load_balancing = False
| 131 | +    model_cfg.pipeline_model_parallel_size = args.pipeline_parallel_size
| 132 | +    model_cfg.virtual_pipeline_model_parallel_size = args.virtual_pipeline_parallel_size
| 133 | +    set_deepseek_v3_pipeline_model_parallel_layout(model_cfg)
141 | 134 |
142 | 135 |     # Training configuration
143 | 136 |     train_cfg = config.train
@@ -168,6 +161,9 @@ def create_config(args):
168 | 161 |         seed=args.seed,
169 | 162 |     )
170 | 163 |
| 164 | +    # Tokenizer configuration
| 165 | +    config.tokenizer = get_tokenizer_config()
| 166 | +
171 | 167 |     # Checkpoint configuration
172 | 168 |     checkpoint_cfg = config.checkpoint
173 | 169 |     checkpoint_cfg.load = "/checkpoint"
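The net effect of the change is easiest to see with the pieces assembled. Below is a minimal sketch of the resulting configuration flow, reconstructed only from the hunks above: the function name `build_config_sketch` is hypothetical, the `args` attribute names are taken from the diff, and the dataset, training, and checkpoint sections of `create_config` are omitted.

```python
# Minimal sketch assembled from the hunks in this diff; not the full script.
from megatron.bridge.recipes.deepseek import (
    deepseek_v3_pretrain_config,
    set_deepseek_v3_pipeline_model_parallel_layout,
)
from megatron.bridge.training.tokenizers.config import TokenizerConfig


def build_config_sketch(args):
    # Start from the recipe defaults; parallelism is no longer passed as kwargs.
    config = deepseek_v3_pretrain_config()

    # Pipeline parallel sizes are set on the model config, then the DeepSeek V3
    # helper derives the per-stage pipeline layout.
    model_cfg = config.model
    model_cfg.pipeline_model_parallel_size = args.pipeline_parallel_size
    model_cfg.virtual_pipeline_model_parallel_size = args.virtual_pipeline_parallel_size
    set_deepseek_v3_pipeline_model_parallel_layout(model_cfg)

    # Tokenizer is configured explicitly, as in the new get_tokenizer_config() helper.
    config.tokenizer = TokenizerConfig(
        tokenizer_type="HuggingFaceTokenizer",
        tokenizer_model="/tokenizer",  # container-mounted tokenizer path from the diff
        hf_tokenizer_kwargs={"use_fast": True},
    )
    return config
```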