
Commit f8521b3

align the deepseek v3 training parameters (#7)
1 parent 4690a5c commit f8521b3

5 files changed: 19 additions, 20 deletions


examples/deepseek_v3/exp_pretrain.yaml

Lines changed: 11 additions & 4 deletions
@@ -16,11 +16,11 @@ modules:
     overrides:
       # log
       wandb_project: "Primus_DeepSeekV3_Pretrain"
-      disable_wandb: false
+      # disable_wandb: false
       stderr_sink_level: DEBUG
 
       # debug
-      num_layers: 4
+      # num_layers: 4
 
       # hyber parameters
       train_iters: 10
@@ -33,6 +33,12 @@ modules:
       lr_warmup_iters: 2
       lr_decay_iters: null
       lr_decay_style: cosine
+      weight_decay: 0.1
+      adam_beta1: 0.9
+      adam_beta2: 0.95
+      eod_mask_loss: true
+      init_method_std: 0.008
+      norm_epsilon: 1.0e-6
 
       # parallel
       tensor_model_parallel_size: 1
@@ -41,8 +47,8 @@ modules:
 
       # data
       train_data_path: /home/azureuser/tas-public/data/deepseek-datasets/mmap_deepseekv2_datasets_text_document
-      valid_data_path: /home/azureuser/tas-public/data/deepseek-datasets/mmap_deepseekv2_datasets_text_document
-      test_data_path: /home/azureuser/tas-public/data/deepseek-datasets/mmap_deepseekv2_datasets_text_document
+      valid_data_path: null
+      test_data_path: null
 
       # fusion
       # 20250317: need latest apex in docker image
@@ -61,3 +67,4 @@ modules:
       no_save_optim: null
       no_save_rng: null
       disable_last_saving: true
+      ckpt_format: torch
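
The new overrides (weight_decay, adam_beta1/2, eod_mask_loss, init_method_std, norm_epsilon) are the optimizer and initialization settings the commit title says are being aligned with DeepSeek V3, and the valid/test data paths are cleared so only the training split is read. As a rough, non-repository sketch of where hyperparameters like these end up, the snippet below feeds them into a plain PyTorch AdamW optimizer; the dict and helper are illustrative only.

import torch

# Values mirroring the overrides added above; this helper is illustrative,
# not code from this repository.
overrides = {
    "weight_decay": 0.1,
    "adam_beta1": 0.9,
    "adam_beta2": 0.95,
    "init_method_std": 0.008,
    "norm_epsilon": 1.0e-6,
}

def build_optimizer(model: torch.nn.Module, cfg: dict) -> torch.optim.AdamW:
    # Adam betas and decoupled weight decay taken straight from the config.
    return torch.optim.AdamW(
        model.parameters(),
        lr=1e-4,  # placeholder; the real schedule comes from the lr_* overrides
        betas=(cfg["adam_beta1"], cfg["adam_beta2"]),
        weight_decay=cfg["weight_decay"],
    )

optimizer = build_optimizer(torch.nn.Linear(8, 8), overrides)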

primus/configs/modules/megatron/trainer_base.yaml

Lines changed: 0 additions & 2 deletions
@@ -145,9 +145,7 @@ ddp_bucket_size: null # int
 ddp_pad_buckets_for_high_nccl_busbw: false
 ddp_average_in_collective: false
 overlap_grad_reduce: false
-delay_grad_reduce: true
 overlap_param_gather: false
-delay_param_gather: false
 overlap_param_gather_with_optimizer_step: false
 align_param_gather: true
 scatter_gather_tensors_in_pipeline: true
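
The two delay_* keys disappear from the DDP block, presumably because the trainer no longer accepts them; stale keys in a base config are exactly what trips strict argument validation. A generic sketch of that failure mode, using a hypothetical config dataclass rather than Megatron's actual one:

from dataclasses import dataclass, fields

@dataclass
class DDPConfig:
    # Hypothetical stand-in for a DDP config; only a few fields shown.
    overlap_grad_reduce: bool = False
    overlap_param_gather: bool = False
    align_param_gather: bool = True

overrides = {"overlap_grad_reduce": False, "delay_grad_reduce": True}
unknown = set(overrides) - {f.name for f in fields(DDPConfig)}
print(unknown)  # {'delay_grad_reduce'} -- the kind of stale key this commit drops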

primus/modules/trainer/base_trainer.py

Lines changed: 3 additions & 6 deletions
@@ -11,17 +11,14 @@
 
 
 class BaseTrainer(ABC):
-    # def get_batch_func(self):
     @abstractmethod
     def get_batch(self, data_iterator):
-        raise NotImplementedError
+        pass
 
-    # def get_loss_func(self):
     @abstractmethod
     def loss_func(self, loss_mask: torch.Tensor, output_tensor: torch.Tensor):
-        raise NotImplementedError
+        pass
 
-    # def get_forward_step_func(self):
     @abstractmethod
     def forward_step(self, data_iterator, model: GPTModel):
-        raise NotImplementedError
+        pass
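
Swapping raise NotImplementedError for pass in these stubs does not weaken the contract: methods marked @abstractmethod are enforced by ABCMeta when a trainer is instantiated, so a concrete subclass must override them no matter what the stub body contains. A minimal demonstration:

from abc import ABC, abstractmethod

class BaseTrainer(ABC):
    @abstractmethod
    def get_batch(self, data_iterator):
        pass  # body is irrelevant; ABCMeta enforces the override

try:
    BaseTrainer()  # instantiation fails even though the stub body is just `pass`
except TypeError as exc:
    print(exc)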

primus/modules/trainer/megatron/sft_trainer.py

Lines changed: 3 additions & 6 deletions
@@ -6,14 +6,11 @@ def __init__(self, *args, **kwargs):
         kwargs["module_name"] = "sft_trainer"
         super().__init__(*args, **kwargs)
 
-    def get_batch_func(self):
+    def get_batch(self, data_iterator):
         raise NotImplementedError
 
-    def get_loss_func(self):
+    def loss_func(self, loss_mask: torch.Tensor, output_tensor: torch.Tensor):
         raise NotImplementedError
 
-    def build_dataset_and_tokenizer(self):
-        raise NotImplementedError
-
-    def get_forward_step_func(self):
+    def forward_step(self, data_iterator, model: GPTModel):
         raise NotImplementedError
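
The SFT stubs are renamed so they match BaseTrainer's abstract interface (get_batch, loss_func, forward_step), and the unused build_dataset_and_tokenizer placeholder is dropped. The rename matters because a subclass only satisfies an ABC when the method names line up, as this generic example (not the repository's classes) shows:

from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def get_batch(self, data_iterator): ...

class WrongName(Base):
    def get_batch_func(self):  # old-style name; get_batch stays abstract
        raise NotImplementedError

class RightName(Base):
    def get_batch(self, data_iterator):  # matches the abstract signature
        raise NotImplementedError

RightName()      # instantiates fine
try:
    WrongName()  # TypeError: get_batch is still abstract
except TypeError as exc:
    print(exc)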

primus/modules/trainer/megatron/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,6 @@
 import os
 import time
 from contextlib import nullcontext
-from datetime import datetime
 from typing import Union
 
 import megatron
@@ -1721,7 +1720,8 @@ def training_log(
             writer.add_scalar("iteration-time", elapsed_time_per_iteration, iteration)
         if wandb_writer:
             wandb_writer.log({"iteration-time": elapsed_time_per_iteration}, iteration)
-        log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
+        # log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
+        log_string = f""
         if hasattr(self, "episode_count") and self.episode_count is not None:
             log_string += f" episode {self.episode_count} |"
         log_string += " iteration {:8d}/{:8d} |".format(iteration, args.train_iters)
