
Commit f2ed869

deepseek v3 pretrain
1 parent 9940a17 commit f2ed869

9 files changed (+582, -36 lines)

examples/deepseek_v3/exp_pretrain.yaml

Lines changed: 3 additions & 2 deletions

@@ -6,7 +6,7 @@ workspace: ./output
 platform:
   config: platform_azure.yaml
   overrides:
-    master_sink_level: DEBUG
+    master_sink_level: INFO

 modules:
   pre_trainer:
@@ -16,7 +16,7 @@ modules:
     overrides:
       # log
       wandb_project: "Primus_DeepSeekV3_Pretrain"
-      # disable_wandb: false
+      disable_wandb: false
      stderr_sink_level: DEBUG

      # debug
@@ -60,3 +60,4 @@ modules:
      save_interval: 20000
      no_save_optim: null
      no_save_rng: null
+     disable_last_saving: true

examples/deepseek_v3/pretrain.py

Lines changed: 1 addition & 1 deletion

@@ -32,4 +32,4 @@
 log_init(primus_cfg, trainer.platform)

 trainer.init()
-# trainer.run()
+trainer.run()

primus/configs/models/megatron/deepseek_v3_base.yaml

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ moe_router_bias_update_rate: 1.0e-3
 moe_router_load_balancing_type: seq_aux_loss
 moe_token_dispatcher_type: alltoall
 moe_shared_expert_overlap: true
-moe_aux_loss_coeff: ${moe_aux_loss_coeff:1.0e-2}
+moe_aux_loss_coeff: 1.0e-2

 # parallel and optimization
 expert_model_parallel_size: 1

primus/configs/modules/megatron/trainer_base.yaml

Lines changed: 3 additions & 0 deletions

@@ -179,6 +179,7 @@ train_iters: null
 eval_iters: 32
 eval_interval: 2000
 skip_train: false
+train_sync_interval: null # int

 adlr_autoresume: false
 adlr_autoresume_interval: 1000
@@ -228,6 +229,7 @@ profile: false
 profile_ranks: [0]
 profile_step_end: 12
 profile_step_start: 10
+iterations_to_skip: null
 result_rejected_tracker_filename: null
 enable_gloo_process_groups: true
 record_memory_history: false
@@ -351,3 +353,4 @@ parallel_output: false

 enable_ft_package: false
 calc_ft_timeouts: false
+run_workload_inspector_server: false

primus/modules/base_module.py

Lines changed: 1 addition & 0 deletions

@@ -107,6 +107,7 @@ def setup_worker_logger(self, rank, world_size):

         # monkey patch print function of builtins
         self.original_print = builtins.print
+        # builtins.print = log_rank_all
         builtins.print = debug_rank_all

         # disable all logging handlers
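The added comment records the alternative that was considered: routing print through log_rank_all (the INFO-on-every-rank helper added in module_utils.py below) rather than the existing debug_rank_all. A minimal sketch of what the active patch means for worker code, assuming setup_worker_logger() has already run in the process (debug_rank_all's behaviour is inferred from its name):

import builtins

# Once setup_worker_logger() has run, builtins.print points at debug_rank_all,
# so a bare print() in any worker module goes through the rank-aware logger
# instead of being written raw to stdout.
print("dataloader ready")  # routed through debug_rank_all

# The pre-patch function is kept on the module instance and can be restored:
# builtins.print = self.original_print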

primus/modules/module_utils.py

Lines changed: 25 additions & 0 deletions

@@ -27,6 +27,31 @@ def log_rank_0(msg, *args, **kwargs):
         log_func(msg, module_name, function_name, line)


+def log_rank_last(msg, *args, **kwargs):
+    log_func = logger.info_with_caller
+
+    caller = inspect.stack()[1]
+    caller_frame = caller.frame
+    function_name = caller_frame.f_code.co_name
+    module_name = caller_frame.f_globals["__name__"].split(".")[-1]
+    line = caller.lineno
+
+    if _rank == _world_size - 1:
+        log_func(msg, module_name, function_name, line)
+
+
+def log_rank_all(msg, *args, **kwargs):
+    log_func = logger.info_with_caller
+
+    caller = inspect.stack()[1]
+    caller_frame = caller.frame
+    function_name = caller_frame.f_code.co_name
+    module_name = caller_frame.f_globals["__name__"].split(".")[-1]
+    line = caller.lineno
+
+    log_func(msg, module_name, function_name, line)
+
+
 def log_kv_rank_0(key, value):
     log_func = logger.log_kv_with_caller
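For reference, a usage sketch of the new helpers alongside the existing log_rank_0 (the _rank and _world_size globals they consult are module state initialized elsewhere in module_utils; the call sites below are illustrative):

from primus.modules.module_utils import log_rank_0, log_rank_all, log_rank_last

log_rank_0("building tokenizer")            # emitted only on the first rank
log_rank_last("last pipeline stage ready")  # emitted only on rank _world_size - 1
log_rank_all("per-rank data shard loaded")  # emitted on every rank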

primus/modules/trainer/base_trainer.py

Lines changed: 9 additions & 7 deletions

@@ -6,20 +6,22 @@

 from abc import ABC, abstractmethod

+import torch
+from megatron.core.models.gpt import GPTModel

-class BaseTrainer(ABC):
-    @abstractmethod
-    def get_batch_func(self):
-        raise NotImplementedError

+class BaseTrainer(ABC):
+    # def get_batch_func(self):
     @abstractmethod
-    def get_loss_func(self):
+    def get_batch(self, data_iterator):
         raise NotImplementedError

+    # def get_loss_func(self):
     @abstractmethod
-    def build_dataset_and_tokenizer(self):
+    def loss_func(self, loss_mask: torch.Tensor, output_tensor: torch.Tensor):
         raise NotImplementedError

+    # def get_forward_step_func(self):
     @abstractmethod
-    def get_forward_step_func(self):
+    def forward_step(self, data_iterator, model: GPTModel):
         raise NotImplementedError
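The abstract hooks move from *_func getters to the get_batch / loss_func / forward_step shape that Megatron-LM's training loop calls directly, and build_dataset_and_tokenizer is no longer declared abstract here. Because the new names stay @abstractmethod, the rename is enforced at construction time; a self-contained sketch of that effect (the BaseTrainer stand-in below mirrors the new interface rather than importing the real class):

from abc import ABC, abstractmethod


class BaseTrainer(ABC):  # stand-in mirroring primus/modules/trainer/base_trainer.py
    @abstractmethod
    def get_batch(self, data_iterator): ...

    @abstractmethod
    def loss_func(self, loss_mask, output_tensor): ...

    @abstractmethod
    def forward_step(self, data_iterator, model): ...


class LegacyTrainer(BaseTrainer):
    def get_batch_func(self):  # old hook name, no longer part of the interface
        return None


LegacyTrainer()  # raises TypeError: abstract methods forward_step, get_batch, loss_func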

primus/modules/trainer/megatron/pre_trainer.py

Lines changed: 111 additions & 8 deletions

@@ -4,6 +4,17 @@
 # See LICENSE for license information.
 #################################################################################

+from functools import partial
+
+import torch
+from megatron.core import mpu
+from megatron.core.models.gpt import GPTModel
+from megatron.core.rerun_state_machine import get_rerun_state_machine
+from megatron.core.utils import StragglerDetector
+from megatron.training import get_args, get_timers
+from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank
+
+stimer = StragglerDetector()

 from .trainer import MegatronTrainer

@@ -13,14 +24,106 @@ def __init__(self, *args, **kwargs):
         kwargs["module_name"] = "pre_trainer"
         super().__init__(*args, **kwargs)

-    def get_batch_func(self):
-        raise NotImplementedError
+    def get_batch(self, data_iterator):
+        """Generate a batch."""
+
+        # TODO: this is pretty hacky, find a better way
+        if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+            return None, None, None, None, None
+
+        # get batches based on the TP rank you are on
+        batch = get_batch_on_this_tp_rank(data_iterator)
+
+        # slice batch along sequence dimension for context parallelism
+        batch = get_batch_on_this_cp_rank(batch)
+
+        return batch.values()
+
+    def loss_func(self, loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+        """Loss function.
+
+        Args:
+            loss_mask (torch.Tensor): Used to mask out some portions of the loss
+            output_tensor (torch.Tensor): The tensor with the losses
+
+        Returns:
+            the loss scalar for this micro-batch
+            the number of non-padded tokens in this microbatch
+            a dict containing reporting metrics on the loss and number of tokens across
+                the data parallel ranks
+        """
+        args = get_args()
+
+        losses = output_tensor.float()
+        loss_mask = loss_mask.view(-1).float()
+        total_tokens = loss_mask.sum()
+        loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+
+        if args.context_parallel_size > 1:
+            torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+
+        # Check individual rank losses are not NaN prior to DP all-reduce.
+        rerun_state_machine = get_rerun_state_machine()
+        if args.check_for_nan_in_loss_and_grad:
+            rerun_state_machine.validate_result(
+                result=loss[0],
+                rejection_func=torch.isnan,
+                message="found NaN in local forward loss calculation",
+                tolerance=0.0,  # forward pass calculations are deterministic
+                fatal=True,
+            )
+            rerun_state_machine.validate_result(
+                result=loss[0],
+                rejection_func=torch.isinf,
+                message="found Inf in local forward loss calculation",
+                tolerance=0.0,  # forward pass calculations are deterministic
+                fatal=True,
+            )
+        # Check for spiky loss
+        if args.check_for_spiky_loss:
+            rerun_state_machine.validate_result(
+                result=loss[0],
+                rejection_func=partial(
+                    rerun_state_machine.is_unexpectedly_large,
+                    threshold=SPIKY_LOSS_FACTOR,
+                    context="loss",
+                ),
+                message="Spiky loss",
+                tolerance=0.0,  # forward pass calculations are deterministic
+                fatal=False,
+            )
+        # Reduce loss for logging.
+        reporting_loss = loss.clone().detach()
+        torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

+        # loss[0] is a view of loss, so it has ._base not None, which triggers assert error
+        # in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
+        # on loss[0] fixes this
+        local_num_tokens = loss[1].clone().detach().to(torch.int)
+        return (
+            loss[0].clone(),
+            local_num_tokens,
+            {"lm loss": (reporting_loss[0], reporting_loss[1])},
+        )
+
+    def forward_step(self, data_iterator, model: GPTModel):
+        """Forward training step.
+
+        Args:
+            data_iterator : Input data iterator
+            model (GPTModel): The GPT Model
+        """
+        get_args()
+        timers = get_timers()

-    def get_loss_func(self):
-        raise NotImplementedError
+        # Get the batch.
+        timers("batch-generator", log_level=2).start()
+        global stimer
+        with stimer(bdata=True):
+            tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator)
+        timers("batch-generator").stop()

-    def build_dataset_and_tokenizer(self):
-        raise NotImplementedError
+        with stimer:
+            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

-    def get_forward_step_func(self):
-        raise NotImplementedError
+        return output_tensor, partial(self.loss_func, loss_mask)
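These methods closely follow the get_batch / loss_func / forward_step functions in Megatron-LM's pretrain_gpt.py, moved onto the trainer class (SPIKY_LOSS_FACTOR is not imported in this hunk; it is a module-level constant in the upstream file, and the branch that uses it is only reached when args.check_for_spiky_loss is enabled). forward_step returns the model output plus a loss closure built with functools.partial, which is the contract Megatron's pipeline schedules expect from a forward-step callback. A sketch of how a run loop could hand the bound method to Megatron; the wiring below is an assumption about MegatronTrainer.run, not part of this diff:

from megatron.core.pipeline_parallel import get_forward_backward_func


def run_one_step(trainer, train_data_iterator, model, args, num_microbatches):
    """Hypothetical helper showing how the bound forward_step plugs into Megatron."""
    forward_backward_func = get_forward_backward_func()
    # forward_step is invoked as forward_step(data_iterator, model) and must return
    # (output_tensor, loss_closure) -- exactly what PreTrainer.forward_step produces.
    return forward_backward_func(
        forward_step_func=trainer.forward_step,
        data_iterator=train_data_iterator,
        model=model,
        num_microbatches=num_microbatches,
        seq_length=args.seq_length,
        micro_batch_size=args.micro_batch_size,
        forward_only=False,
    )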
