-
Notifications
You must be signed in to change notification settings - Fork 420
Expand file tree
/
Copy pathtest_moe_train_engine.py
More file actions
464 lines (413 loc) · 17.4 KB
/
test_moe_train_engine.py
File metadata and controls
464 lines (413 loc) · 17.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
import os
from pathlib import Path
import tempfile
import shutil
import time
from torch.distributed.device_mesh import init_device_mesh
import parametrize
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from xtuner._testing import DeterministicDDPTestCase
from transformers import AutoTokenizer
from collections import defaultdict
from xtuner.v1.model.moe.moe import SequenceContext
from xtuner.v1.model.base import ModelItem
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config, Qwen3MoEConfig
from xtuner.v1.config import FSDPConfig, LRConfig, AdamWConfig
from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig
from xtuner.v1.engine.train_engine import TrainEngine
from torch.optim.lr_scheduler import LambdaLR
from xtuner.v1.utils import pad_to_max_length
from xtuner.v1.utils.device import get_device
from xtuner.v1.utils.test_utils import init_data_mesh
from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEDecoderLayer
from xtuner.v1.model import get_model_config_from_hf
# Checkpoint locations are taken from the environment. If either variable is
# unset, importing this module raises KeyError. QWEN3_MOE_PATH is a Qwen3 30B A3
# MoE checkpoint; QWEN3_MOE_FOPE_PATH is its FOPE counterpart.
QWEN3_MOE_PATH, QWEN3_MOE_FOPE_PATH = (
    os.environ[env_name] for env_name in ("QWEN3_MOE_PATH", "QWEN3_MOE_FOPE_PATH")
)
# Accelerator that every test in this module places its inputs on.
DEVICE = get_device()
class TestMoEEngine(DeterministicDDPTestCase):
    """Distributed tests for ``TrainEngine`` on Qwen3-MoE models.

    Covers:
    * short training runs (reference loss curve, frozen routers);
    * HF save/load round trips;
    * DCP checkpoint save/load of model and optimizer state;
    * reloading optimizer state under a new optimizer config.
    """

    # Every training sample is padded to this many tokens.
    _MAX_LENGTH = 8192

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _destroy_pg(pg) -> None:
        """Destroy ``pg`` on a best-effort basis.

        Errors are deliberately ignored; presumably the group may already be
        gone. Only ``Exception`` is caught so that ``KeyboardInterrupt`` and
        ``SystemExit`` still propagate.
        """
        try:
            dist.destroy_process_group(pg)
        except Exception:
            pass

    @staticmethod
    def _make_shared_tmpdir() -> str:
        """Create a temporary directory on rank 0 and return its path on every rank.

        Only rank 0 creates a directory, so no per-rank directories are left behind.
        Remove it with ``_remove_shared_tmpdir``.
        NOTE(review): assumes all ranks share a filesystem (single node) — confirm.
        """
        holder = [tempfile.mkdtemp() if dist.get_rank() == 0 else None]
        dist.broadcast_object_list(holder, src=0)
        return holder[0]

    @staticmethod
    def _remove_shared_tmpdir(path) -> None:
        """Delete a directory created by ``_make_shared_tmpdir`` (rank 0 only)."""
        if dist.get_rank() == 0:
            shutil.rmtree(path, ignore_errors=True)

    def _build_engine_and_scheduler(self, ep_size, **model_overrides):
        """Build a Qwen3-MoE-30B-A3 engine loaded from ``QWEN3_MOE_PATH`` plus its LR scheduler.

        Args:
            ep_size: expert-parallel size, used for both the model and FSDP config.
            **model_overrides: extra ``Qwen3MoE30BA3Config`` fields
                (e.g. ``freeze_routers=True``).

        Returns:
            ``(engine, lr_scheduler)``. The scheduler scales the LR linearly over
            ``1000 * LRConfig().warmup_ratio`` steps, then holds the factor at 1.
        """
        moe_cfg = Qwen3MoE30BA3Config(
            ep_size=ep_size,
            balancing_loss_cfg=BalancingLossConfig(),
            z_loss_cfg=ZLossConfig(),
            compile_cfg=False,
            **model_overrides,
        )
        optim_cfg = AdamWConfig()
        lr_cfg = LRConfig()
        fsdp_cfg = FSDPConfig(cpu_offload=False, ep_size=ep_size)
        engine = TrainEngine(model_cfg=moe_cfg, optim_cfg=optim_cfg, fsdp_cfg=fsdp_cfg)
        engine.from_hf(hf_path=QWEN3_MOE_PATH)

        total_steps = 1000
        warmup_steps = total_steps * lr_cfg.warmup_ratio

        def warmup_fn(step):
            # When warmup_steps == 0 the condition is never true, so there is
            # no division by zero.
            return step / warmup_steps if step < warmup_steps else 1

        return engine, LambdaLR(engine.optimizer, warmup_fn)

    @classmethod
    def _build_padded_sample(cls):
        """Tokenize a fixed sentence into one padded next-token-prediction sample.

        Returns:
            ``(input_ids, labels, num_padding)``.
            * ``labels`` are the tokens shifted left by one.
            * ``input_ids`` is padded with 0 and ``labels`` with -100, both up to
              ``_MAX_LENGTH``.
            * ``num_padding`` is ``_MAX_LENGTH`` minus the unpadded input length.
        """
        tok = AutoTokenizer.from_pretrained(QWEN3_MOE_PATH)
        txt = "根据国际地球自转和参考系服务机构的数据,今年夏天是自2020年以来第六次地球自转加速。7月9日将成为有史以来最短的一天,比平时短1.3到1.6毫秒。 "
        token_ids = tok.encode(txt, return_tensors="pt").view(1, -1)
        labels = token_ids.clone()[:, 1:]
        input_ids = token_ids[:, :-1]
        num_padding = cls._MAX_LENGTH - input_ids.shape[1]
        input_ids = pad_to_max_length(input_ids, 0, max_length=cls._MAX_LENGTH)
        labels = pad_to_max_length(labels, -100, max_length=cls._MAX_LENGTH)
        return input_ids, labels, num_padding

    @staticmethod
    def _train_one_step(engine, lr_scheduler, loss_cfg, input_ids, labels, num_padding):
        """Run one train step, clip grads, step optimizer and scheduler.

        ``labels`` must already be on ``DEVICE``. A fresh ``SequenceContext`` is
        built on every call.

        Returns:
            The ``reduced_llm_loss`` entry of the step's ``logs_info``.
        """
        seq_ctx = SequenceContext.from_input_ids((input_ids,), device=DEVICE)
        seq_ctx.num_padding = num_padding
        loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
        loss_ctx = loss_cfg.loss_ctx_cls.build_batches([loss_ctx])[0]
        engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
        loss_log = engine.train_step(engine_input)["logs_info"]
        grad_norm = engine.clip_grad_norm()
        engine.step_optimizer(grad_norm)
        lr_scheduler.step()
        return loss_log["reduced_llm_loss"]

    # -------------------------------------------------------------------- tests

    @parametrize.parametrize(
        "device,ep_size,sp_size",
        [
            ("cuda", 1, 1),
            ("cuda", 1, 2),
        ],
    )
    def test_moe_engine_train(self, device, ep_size, sp_size):
        """Train 10 steps on one padded sample and compare against a reference loss curve."""
        pg = self.create_pg(device)
        engine, lr_scheduler = self._build_engine_and_scheduler(ep_size)
        loss_cfg = CELossConfig()
        input_ids, labels, num_padding = self._build_padded_sample()
        labels = labels.to(DEVICE)  # loop-invariant: move once

        data_mesh = None
        if sp_size > 1:
            # NOTE(review): data_mesh is never passed to SequenceContext or to the
            # loss (sp_mesh=None), so sp_size=2 does not appear to exercise
            # sequence parallelism — confirm whether it should be wired in.
            data_mesh = init_data_mesh(str(DEVICE), sp_size)

        losses = [
            self._train_one_step(engine, lr_scheduler, loss_cfg, input_ids, labels, num_padding)
            for _ in range(10)
        ]
        losses_ref = torch.tensor([2.44, 2.44, 2.42, 2.41, 2.34, 2.33, 2.16, 2.13, 1.71, 1.55])
        self._check_loss_curve(torch.tensor(losses), losses_ref)

        torch.cuda.empty_cache()
        self._destroy_pg(pg)

    @parametrize.parametrize(
        "device,ep_size,sp_size",
        [
            ("cuda", 1, 1),
        ],
    )
    def test_moe_engine_train_freeze_routers(self, device, ep_size, sp_size):
        """With ``freeze_routers=True`` the router (gate) weights must stay unchanged while training."""
        pg = self.create_pg(device)
        engine, lr_scheduler = self._build_engine_and_scheduler(ep_size, freeze_routers=True)
        loss_cfg = CELossConfig()
        input_ids, labels, num_padding = self._build_padded_sample()
        labels = labels.to(DEVICE)

        data_mesh = None
        if sp_size > 1:
            data_mesh = init_data_mesh(str(DEVICE), sp_size)

        # Frozen routers must be non-trainable leaf parameters. Record their
        # initial mean/std as the reference for every later step.
        gate_means = defaultdict(list)
        gate_stds = defaultdict(list)
        for name, layer in engine.model.layers.items():
            if not isinstance(layer, MoEDecoderLayer):
                continue
            self.assertFalse(layer.gate.weight.requires_grad)
            self.assertTrue(layer.gate.weight.is_leaf)
            gate = layer.gate.weight.full_tensor()  # gathers across ranks
            gate_means[name].append(gate.mean())
            gate_stds[name].append(gate.std())

        for _ in range(4):
            self._train_one_step(engine, lr_scheduler, loss_cfg, input_ids, labels, num_padding)
            for name, layer in engine.model.layers.items():
                if not isinstance(layer, MoEDecoderLayer):
                    continue
                # Gather once per layer; reuse it for both statistics and messages.
                gate = layer.gate.weight.full_tensor()
                mean, std = gate.mean(), gate.std()
                if not torch.equal(mean, gate_means[name][-1]):
                    self.fail(f"Mismatch in gate mean in layer {name}, {mean} and {gate_means[name][-1]}")
                if not torch.equal(std, gate_stds[name][-1]):
                    self.fail(f"Mismatch in gate std in layer {name}, {std} and {gate_stds[name][-1]}")
                gate_means[name].append(mean)
                gate_stds[name].append(std)

        torch.cuda.empty_cache()
        self._destroy_pg(pg)

    @parametrize.parametrize(
        "device,ep_size,hsdp_sharding_size",
        [
            ("cuda", 1, 8),  # TODO: test ep8 and hsdp (currently OOM on 8 GPUs)
        ],
    )
    def test_save_and_load(self, device, ep_size, hsdp_sharding_size):
        """Round-trip model weights through ``save_hf``/``from_hf`` and compare them in bf16."""
        pg = self.create_pg(device)
        temp_dir = self._make_shared_tmpdir()
        try:
            moe_cfg = Qwen3MoE30BA3Config(
                ep_size=ep_size,
                balancing_loss_cfg=BalancingLossConfig(),
                z_loss_cfg=ZLossConfig(),
                compile_cfg=False,
            )
            optim_cfg = AdamWConfig()
            fsdp_cfg = FSDPConfig(
                cpu_offload=False,
                ep_size=ep_size,
                hsdp_sharding_size=hsdp_sharding_size,
            )
            engine = TrainEngine(model_cfg=moe_cfg, optim_cfg=optim_cfg, fsdp_cfg=fsdp_cfg)
            engine.init_model_weights()
            engine.from_hf(hf_path=QWEN3_MOE_PATH)
            engine.save_hf(hf_dir=temp_dir, save_dtype=torch.bfloat16)
            dist.barrier()
            time.sleep(1)

            engine2 = TrainEngine(model_cfg=moe_cfg, optim_cfg=optim_cfg, fsdp_cfg=fsdp_cfg)
            engine2.from_hf(hf_path=temp_dir)

            state_dict = engine.model.state_dict()
            state_dict2 = engine2.model.state_dict()
            for key, val in state_dict.items():
                # full_tensor() gathers across ranks, so every rank visits every key.
                val = val.full_tensor().bfloat16()
                # NOTE(review): the reloaded tensor is truncated to the original's
                # leading dim — presumably it may be padded along dim 0; confirm.
                val2 = state_dict2[key].full_tensor().bfloat16()[: val.shape[0]]
                if not torch.equal(val, val2):
                    # Format the tensors only on failure; their repr is not free.
                    self.fail(f"Mismatch in {key} after save_hf/from_hf round trip, {val} and {val2}")
            # No rank may still be reading temp_dir when rank 0 deletes it.
            dist.barrier()
        finally:
            self._remove_shared_tmpdir(temp_dir)

        torch.cuda.empty_cache()
        self._destroy_pg(pg)

    @parametrize.parametrize(
        "device,dispatcher,ep_size,load_from_type",
        [
            ("cuda", None, 1, "qwen3"),
            ("cuda", "all2all", 4, "qwen3"),
            ("cuda", "all2all", 8, "qwen3"),
            ("cuda", None, 1, "qwen3_fope"),
            ("cuda", "all2all", 4, "qwen3_fope"),
            ("cuda", "all2all", 8, "qwen3_fope"),
        ],
    )
    def test_checkpoint_save_load(self, device, dispatcher, ep_size, load_from_type):
        """Save model and optimizer state with DCP and reload it into a fresh engine.

        Each rank then compares its local shards of the model and optimizer state.
        """
        pg = self.create_pg(device)
        print(f"dist.get_rank(): {dist.get_rank()}")
        os.environ["LOCAL_RANK"] = str(dist.get_rank())
        torch.accelerator.set_device_index(int(dist.get_rank()))
        assert load_from_type in ["qwen3", "qwen3_fope"]
        load_from = Path(QWEN3_MOE_PATH) if load_from_type == "qwen3" else Path(QWEN3_MOE_FOPE_PATH)

        def local(t):
            # Compare per-rank shards; plain tensors are used as-is.
            return t._local_tensor if isinstance(t, DTensor) else t

        with tempfile.TemporaryDirectory() as local_tmpdir:
            tiny_model = True
            # 1. create
            engine = create_engine_from_hf(load_from, dispatcher, ep_size, tiny=tiny_model)

            # 2. operate: every rank writes to / reads from rank 0's directory.
            syncdir = [local_tmpdir]
            dist.broadcast_object_list(syncdir, src=0)
            tmpdir = Path(syncdir[0])
            # Non-strict for the tiny model, presumably because it has fewer
            # layers than the checkpoint.
            engine.from_hf(load_from, strict=not tiny_model)
            dist.barrier()
            model_dir, optimizer_dir = tmpdir / "model", tmpdir / "optimizer"
            engine.save_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir)
            dist.barrier()
            time.sleep(1)
            engine2 = create_engine_from_hf(load_from, dispatcher, ep_size, tiny=tiny_model)
            engine2.init_model_weights()
            engine2.load_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir)

            # 3. check the model state
            state_dict = engine.model.state_dict()
            state_dict2 = engine2.model.state_dict()
            assert len(state_dict) == len(state_dict2)
            if load_from_type == "qwen3_fope":
                assert 'rotary_emb.sin_coef' in state_dict
                assert 'rotary_emb.cos_coef' in state_dict
            for key, val in state_dict.items():
                val = local(val)
                val2 = local(state_dict2[key])
                self.assertTrue(
                    torch.equal(val, val2),
                    f"Mismatch in {key}, val shape {val.shape} and val2 shape {val2.shape}",
                )

            # check the optimizer state: {param index: {state name: tensor, ...}, ...}
            opt_state = engine.optimizer.state_dict()['state']
            opt_state2 = engine2.optimizer.state_dict()['state']
            assert len(opt_state) == len(opt_state2)
            assert len(opt_state) != 0
            for param_id, cur_state_dict in opt_state.items():
                cur_state_dict2 = opt_state2[param_id]
                assert len(cur_state_dict) == len(cur_state_dict2)
                assert len(cur_state_dict) != 0
                for state_key, val in cur_state_dict.items():
                    val = local(val)
                    val2 = local(cur_state_dict2[state_key])
                    self.assertTrue(
                        torch.equal(val, val2),
                        f"Mismatch in optimizer state {param_id}/{state_key}, "
                        f"val shape {val.shape} and val2 shape {val2.shape}",
                    )
            # Keep rank 0's directory alive until every rank is done with it.
            dist.barrier()

        torch.cuda.empty_cache()
        self._destroy_pg(pg)

    @parametrize.parametrize(
        "device",
        [
            ("cuda",),
        ],
    )
    def test_load_optimizer_with_new_lr(self, device):
        """Check the ``load_args`` / ``load_states`` flags of ``load_dcp``.

        * ``load_args=False``: optimizer state is restored, but the engine keeps
          its own ``lr``/``eps``.
        * ``load_states=False``: no optimizer state is restored, but ``lr``/``eps``
          come from the checkpoint.
        """
        pg = self.create_pg(device)
        temp_dir = Path(self._make_shared_tmpdir())
        model_dir = temp_dir / "model"
        optimizer_dir = temp_dir / "optimizer"
        try:
            moe_cfg = Qwen3MoE30BA3Config(
                num_hidden_layers=2,
            )
            fsdp_cfg = FSDPConfig()

            lr1, eps1 = 1e-4, 1e-7
            engine = TrainEngine(
                model_cfg=moe_cfg,
                optim_cfg=AdamWConfig(lr=lr1, eps=eps1),
                fsdp_cfg=fsdp_cfg,
            )
            engine.init_model_weights()
            engine.save_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir)
            dist.barrier()
            time.sleep(1)

            lr2, eps2 = 1e-3, 1e-5
            engine2 = TrainEngine(
                model_cfg=moe_cfg,
                optim_cfg=AdamWConfig(lr=lr2, eps=eps2),
                fsdp_cfg=fsdp_cfg,
            )
            engine2.load_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir, load_args=False)
            assert len(engine.optimizer.state) == len(engine2.optimizer.state)
            assert len(engine.optimizer.state) != 0
            for param_group in engine2.optimizer.param_groups:
                assert param_group['lr'] == lr2
                assert param_group['eps'] == eps2

            lr3, eps3 = 1e-1, 1e-3
            engine3 = TrainEngine(
                model_cfg=moe_cfg,
                optim_cfg=AdamWConfig(lr=lr3, eps=eps3),
                fsdp_cfg=fsdp_cfg,
            )
            engine3.load_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir, load_states=False)
            assert len(engine3.optimizer.state) == 0
            for param_group in engine3.optimizer.param_groups:
                assert param_group['lr'] == lr1
                assert param_group['eps'] == eps1
            # No rank may still be reading temp_dir when rank 0 deletes it.
            dist.barrier()
        finally:
            self._remove_shared_tmpdir(temp_dir)

        torch.cuda.empty_cache()
        self._destroy_pg(pg)

    @property
    def world_size(self) -> int:
        """Number of ranks to launch; overridable via ``XTUNER_TEST_WORLD_SIZE`` (default 8)."""
        return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "8"))

    @property
    def destroy_pg_upon_exit(self) -> bool:
        """Tests tear down their own process groups, so the base class should not."""
        return False
def create_engine_from_hf(load_from: Path, dispatcher: str | None, ep_size: int, tiny: bool = False):
    """Build a ``TrainEngine`` whose model config is derived from the checkpoint at ``load_from``.

    The config returned by ``get_model_config_from_hf`` is modified in place:
    * ``dispatcher`` and ``ep_size`` are overridden;
    * compilation is disabled;
    * if ``tiny`` is true, the model is cut to 2 hidden layers.

    The engine uses a default ``AdamWConfig`` and an ``FSDPConfig`` with CPU
    offload off and the same ``ep_size``. Weights are NOT loaded here.
    """
    model_config: Qwen3MoEConfig = get_model_config_from_hf(load_from)
    model_config.dispatcher = dispatcher
    model_config.ep_size = ep_size
    model_config.compile_cfg = False
    if tiny:
        model_config.num_hidden_layers = 2
    return TrainEngine(
        model_cfg=model_config,
        optim_cfg=AdamWConfig(),
        fsdp_cfg=FSDPConfig(cpu_offload=False, ep_size=ep_size),
    )