
Commit 3544339

update LR warmup configs (lr_warmup_iters < train_iters) in config files
1 parent 51ffd59 commit 3544339
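For reference, a minimal sketch of the schedule shape this constraint protects (illustration only, not the Primus/Megatron implementation; the helper name lr_at is hypothetical and its defaults mirror the 370M config below, where a null lr_decay_iters typically falls back to train_iters). If lr_warmup_iters were >= train_iters, the run would spend every iteration on the warmup ramp and never reach the peak LR or the decay phase.

    import math

    def lr_at(step, lr=3.0e-4, min_lr=0.0, warmup_iters=5, decay_iters=50):
        # Linear warmup: ramp up to the peak lr over the first warmup_iters steps.
        if step < warmup_iters:
            return lr * (step + 1) / warmup_iters
        # Cosine decay from lr down to min_lr over the remaining iterations.
        progress = min(1.0, (step - warmup_iters) / max(1, decay_iters - warmup_iters))
        return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * progress))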

3 files changed: +213, -3 lines

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
work_group: ${PRIMUS_TEAM:amd}
user_name: ${PRIMUS_USER:root}
exp_name: ${PRIMUS_EXP_NAME:mamba_370M-pretrain}
workspace: ${PRIMUS_WORKSPACE:./output}

modules:
  pre_trainer:
    framework: megatron
    config: pre_trainer.yaml

    # model to run
    model: mamba_370M.yaml
    overrides:
      # log
      wandb_project: "Primus_Mamba_Pretrain"
      # disable_wandb: false
      # disable_tensorboard: false
      stderr_sink_level: DEBUG

      eval_iters: 0

      log_avg_skip_iterations: 2
      log_avg_reset_interval: 50

      train_iters: 50
      micro_batch_size: 4
      global_batch_size: 256

      seq_length: 2048
      max_position_embeddings: 2048

      lr: 3.0e-4
      min_lr: 0.0
      lr_warmup_iters: 5
      lr_decay_iters: null
      lr_decay_style: cosine
      weight_decay: 0.1
      adam_beta1: 0.9
      adam_beta2: 0.95
      eod_mask_loss: true
      init_method_std: 0.02
      norm_epsilon: 1.0e-5

      # Mamba-specific: must provide spec
      spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec']

      # Tokenizer
      tokenizer_type: HuggingFaceTokenizer
      tokenizer_model: meta-llama/Llama-3.2-1B

      # Mamba SSM parameters
      is_hybrid_model: false
      hybrid_attention_ratio: 0.0
      hybrid_mlp_ratio: 0.0
      mamba_state_dim: 16
      mamba_head_dim: 64
      mamba_num_groups: 8

      # parallel
      tensor_model_parallel_size: 1
      pipeline_model_parallel_size: 1
      expert_model_parallel_size: 1
      overlap_grad_reduce: true
      overlap_param_gather: true
      gradient_accumulation_fusion: false

      # data
      mock_data: true
      train_data_path: null
      valid_data_path: null
      test_data_path: null

      # ckpt
      finetune: false
      auto_continue_train: false
      load: null
      no_load_optim: null
      no_load_rng: null
      save: null
      save_interval: 20000
      no_save_optim: null
      no_save_rng: null
      disable_last_saving: true
      ckpt_format: torch

      # Turbo - may need to disable for Mamba if not supported
      enable_primus_turbo: false
      use_turbo_attention: false
      use_turbo_grouped_mlp: false

      # Cross entropy flags
      # cross_entropy_fusion_impl: "native"
      # cross_entropy_loss_fusion: false
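A minimal validation sketch for the constraint this commit enforces, assuming PyYAML and the nested modules/pre_trainer/overrides layout shown above; the helper name and the file path in the usage comment are hypothetical, not part of Primus.

    import yaml

    def check_warmup(path):
        # Load the experiment config and pull the pre_trainer overrides.
        with open(path) as f:
            cfg = yaml.safe_load(f)
        ov = cfg["modules"]["pre_trainer"]["overrides"]
        warmup, total = ov["lr_warmup_iters"], ov["train_iters"]
        assert warmup < total, f"{path}: lr_warmup_iters ({warmup}) must be < train_iters ({total})"

    # Hypothetical usage:
    # check_warmup("examples/megatron/configs/mamba_370M-pretrain.yaml")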
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
work_group: ${PRIMUS_TEAM:amd}
user_name: ${PRIMUS_USER:root}
exp_name: ${PRIMUS_EXP_NAME:mamba_hybrid_2.8B-pretrain}
workspace: ${PRIMUS_WORKSPACE:./output}

modules:
  pre_trainer:
    framework: megatron
    config: pre_trainer.yaml

    # model to run
    model: mamba_hybrid_2.8B.yaml
    overrides:
      # log
      wandb_project: "Primus_Mamba_Hybrid_Pretrain"
      stderr_sink_level: DEBUG

      eval_iters: 0

      log_avg_skip_iterations: 2
      log_avg_reset_interval: 50

      train_iters: 100
      micro_batch_size: 2
      global_batch_size: 128

      seq_length: 4096
      max_position_embeddings: 4096

      lr: 2.0e-4
      min_lr: 2.0e-5
      lr_warmup_iters: 10
      lr_decay_iters: 100
      lr_decay_style: cosine
      weight_decay: 0.1
      adam_beta1: 0.9
      adam_beta2: 0.95
      eod_mask_loss: true
      init_method_std: 0.02
      norm_epsilon: 1.0e-5

      # Mamba-specific: must provide spec
      spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec']

      # Tokenizer
      tokenizer_type: HuggingFaceTokenizer
      tokenizer_model: meta-llama/Llama-3.2-1B

      # Hybrid Mamba+Attention parameters
      is_hybrid_model: true
      hybrid_attention_ratio: 0.125
      hybrid_mlp_ratio: 0.0
      mamba_state_dim: 16
      mamba_head_dim: 64
      mamba_num_groups: 8

      # parallel
      tensor_model_parallel_size: 2
      pipeline_model_parallel_size: 1
      expert_model_parallel_size: 1
      overlap_grad_reduce: true
      overlap_param_gather: true
      gradient_accumulation_fusion: true

      # data
      mock_data: true
      train_data_path: null
      valid_data_path: null
      test_data_path: null

      # ckpt
      finetune: false
      auto_continue_train: false
      load: null
      save: null
      save_interval: 10000
      disable_last_saving: true
      ckpt_format: torch

      # Turbo - disable for Mamba layers, but attention layers may benefit
      enable_primus_turbo: false
      use_turbo_attention: false
      use_turbo_grouped_mlp: false
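As a rough illustration of what hybrid_attention_ratio: 0.125 means, a back-of-the-envelope sketch: the actual per-layer allocation is decided by Megatron-Core's hybrid layer logic, and num_layers here is a placeholder argument, not a value taken from this config.

    def hybrid_layer_counts(num_layers, attention_ratio=0.125, mlp_ratio=0.0):
        # Approximate split: a ratio of 0.125 puts attention in roughly 1 of every 8 layers.
        attention = round(num_layers * attention_ratio)
        mlp = round(num_layers * mlp_ratio)
        mamba = num_layers - attention - mlp
        return {"mamba": mamba, "attention": attention, "mlp": mlp}

    # e.g. hybrid_layer_counts(64) -> {'mamba': 56, 'attention': 8, 'mlp': 0}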

primus/modules/trainer/megatron/pre_trainer.py

Lines changed: 35 additions & 3 deletions
@@ -235,6 +235,25 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=False):
         DataLoaderStore.push(data_iterator, h2d_stream=False)
         tokens, labels, loss_mask, attention_mask, position_ids = DataLoaderStore.pop()
 
+        # Determine if model supports loss_mask parameter
+        # MambaModel doesn't accept loss_mask in forward(), while GPTModel does
+        model_type = getattr(args, 'model_type', 'gpt')
+        supports_loss_mask = (model_type != 'mamba')
+
+        # Alternative check: inspect the actual model class
+        # This is a fallback in case model_type isn't set correctly
+        if not supports_loss_mask:
+            # Already determined it's Mamba, no need for further checks
+            pass
+        else:
+            # Double-check by inspecting the actual model object
+            from megatron.core.models.mamba import MambaModel
+            from megatron.core.utils import get_attr_wrapped_model
+
+            actual_model = get_attr_wrapped_model(model, 'forward', return_model_obj=True)
+            if isinstance(actual_model, MambaModel):
+                supports_loss_mask = False
+
         with stimer:
             if return_schedule_plan:
                 assert (
@@ -256,17 +275,30 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=False):
                         TransformerModelChunkSchedulePlan,
                     )
 
+                    schedule_kwargs = {"labels": labels}
+                    if supports_loss_mask:
+                        schedule_kwargs["loss_mask"] = loss_mask
+
                     schedule_plan = TransformerModelChunkSchedulePlan(
-                        model, tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+                        model, tokens, position_ids, attention_mask, **schedule_kwargs
                     )
                 else:
+                    schedule_kwargs = {"labels": labels}
+                    if supports_loss_mask:
+                        schedule_kwargs["loss_mask"] = loss_mask
+
                     schedule_plan = model.build_schedule_plan(
-                        tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+                        tokens, position_ids, attention_mask, **schedule_kwargs
                     )
                 return schedule_plan, partial(self.loss_func, loss_mask)
             else:
+                # Build forward kwargs based on model type
+                forward_kwargs = {"labels": labels}
+                if supports_loss_mask:
+                    forward_kwargs["loss_mask"] = loss_mask
+
                 output_tensor = model(
-                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+                    tokens, position_ids, attention_mask, **forward_kwargs
                 )
 
         return output_tensor, partial(self.loss_func, loss_mask)
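The diff gates loss_mask on an isinstance() check against MambaModel. A standalone alternative sketch (an assumption on my part, not the committed implementation) would inspect the unwrapped model's forward() signature directly:

    import inspect

    def forward_accepts_loss_mask(model_obj) -> bool:
        # True if forward() declares loss_mask explicitly or swallows it via **kwargs.
        params = inspect.signature(model_obj.forward).parameters
        return "loss_mask" in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )

    # Mirrors the committed pattern: only pass loss_mask when the model accepts it.
    # kwargs = {"labels": labels}
    # if forward_accepts_loss_mask(actual_model):
    #     kwargs["loss_mask"] = loss_mask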
