-
Notifications
You must be signed in to change notification settings - Fork 419
Expand file tree
/
Copy pathtest_qwen3_dense.py
More file actions
284 lines (246 loc) · 11.1 KB
/
test_qwen3_dense.py
File metadata and controls
284 lines (246 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import os
import json
import parametrize
import torch
from torch.testing._internal.common_distributed import DistributedTestBase
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
import tempfile
from pathlib import Path
from safetensors import safe_open
from xtuner.v1.module.attention import MHAConfig
from xtuner.v1.data_proto import SequenceContext
from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig
from xtuner.v1.config import FSDPConfig
from xtuner.v1.utils.compile import maybe_compile
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner._testing import patch_hf_rms_norm, DeterministicDDPTestCase
# Local directory of the Hugging Face Qwen3-8B checkpoint. The tests load both
# the reference HF model/tokenizer and the xtuner weights from it. It is read
# at import time, so an unset variable raises KeyError during collection.
_QWEN3_PATH_ENV_VAR = "QWEN3_PATH"
QWEN3_PATH = os.environ[_QWEN3_PATH_ENV_VAR]
class TestQwen3Dense(DeterministicDDPTestCase):
    """Compare xtuner's Qwen3-8B dense model against the Hugging Face reference.

    Every test reads the checkpoint (weights and/or tokenizer) from
    ``QWEN3_PATH`` and runs on CUDA.
    """

    # Prompt tokenized by every forward test.
    _PROMPT = "吃葡萄不吐葡萄皮"

    # Maps the ``loss_class`` test parameter onto ``CELossConfig.mode`` so the
    # two parametrized cases exercise different loss implementations.
    # NOTE(review): assumes CELossConfig exposes ``mode`` with the values
    # "eager"/"chunk"; confirm against xtuner.v1.loss.
    _LOSS_CLASS_TO_MODE = {
        "cross_entropy": "eager",
        "chunk_cross_entropy": "chunk",
    }

    @classmethod
    def _encode_prompt(cls, **tokenizer_kwargs) -> torch.Tensor:
        """Tokenize ``_PROMPT`` with the checkpoint's tokenizer and move the ids to CUDA."""
        tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, **tokenizer_kwargs)
        return tokenizer(cls._PROMPT, return_tensors="pt").input_ids.to("cuda")

    @staticmethod
    def _hf_reference_loss(input_ids: torch.Tensor, *, patch_rms_norm: bool) -> torch.Tensor:
        """Return the HF model's loss on ``input_ids``, freeing the HF model afterwards.

        Args:
            input_ids: Token ids already on CUDA.
            patch_rms_norm: Whether to apply ``patch_hf_rms_norm`` to the HF model first.
        """
        hf_model = AutoModelForCausalLM.from_pretrained(
            QWEN3_PATH,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        if patch_rms_norm:
            patch_hf_rms_norm(hf_model)
        with torch.no_grad():
            # HF causal LMs shift ``labels`` themselves; the xtuner path shifts explicitly.
            expected_loss = hf_model(input_ids=input_ids, labels=input_ids.clone()).loss
        # Release the reference model before the xtuner model is materialised.
        del hf_model
        torch.cuda.empty_cache()
        return expected_loss

    @staticmethod
    def _build_model_inputs(input_ids: torch.Tensor, loss_cfg):
        """Build the ``(seq_ctx, loss_ctx)`` pair for one next-token-prediction step.

        The model consumes ``input_ids[:, :-1]`` and is supervised with
        ``input_ids[:, 1:]``.
        """
        shift_input_ids = input_ids[:, :-1]
        shifted_labels = input_ids[:, 1:]
        seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to("cuda"),))
        loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
        # build_batches operates on a list of loss contexts; here there is only one.
        loss_ctx = loss_cfg.loss_ctx_cls.build_batches([loss_ctx])[0]
        return seq_ctx, loss_ctx

    @staticmethod
    def _forward(model, seq_ctx, loss_ctx):
        """Run a gradient-free forward pass with ``loss_ctx`` registered under ``"lm"``."""
        with torch.no_grad():
            return model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})

    def _assert_loss_close(self, loss: torch.Tensor, expected_loss: torch.Tensor, tol: float) -> None:
        """Assert ``loss`` matches ``expected_loss`` with ``atol == rtol == tol``."""
        expected = expected_loss.to(loss.dtype)
        self.assertTrue(
            torch.allclose(loss, expected, atol=tol, rtol=tol),
            msg=f"xtuner loss {loss} != HF loss {expected} (atol=rtol={tol})",
        )

    @parametrize.parametrize(
        "device,tp_size,compile,tol,loss_class",
        [
            ("cuda", 1, False, 1e-2, "cross_entropy"),
            ("cuda", 1, False, 1e-2, "chunk_cross_entropy"),
        ],
    )
    def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class):
        """The unsharded xtuner model must reproduce the (RMSNorm-patched) HF loss."""
        self.create_pg(device)
        input_ids = self._encode_prompt()
        expected_loss = self._hf_reference_loss(input_ids, patch_rms_norm=True)

        with torch.device("meta"):
            cfg = Qwen3Dense8BConfig()
            if not compile:
                cfg.compile_cfg = False
            qwen_model = cfg.build().to(torch.bfloat16)

        # ``loss_class`` used to be ignored, so both cases ran the same default loss.
        loss_cfg = CELossConfig(mode=self._LOSS_CLASS_TO_MODE[loss_class])
        seq_ctx, loss_ctx = self._build_model_inputs(input_ids, loss_cfg)
        qwen_model.from_hf(QWEN3_PATH)
        loss = self._forward(qwen_model, seq_ctx, loss_ctx)["loss"]
        self._assert_loss_close(loss, expected_loss, tol)

    @parametrize.parametrize(
        "device,tp_size",
        [
            ("cuda", 1),
        ],
    )
    def test_fsdp_accuracy(self, device, tp_size):
        """The FSDP-sharded xtuner model must match the HF loss within 1e-2."""
        self.create_pg(device)
        input_ids = self._encode_prompt(trust_remote_code=True)
        # NOTE(review): unlike test_qwen3_dense_run, the HF RMSNorm is not
        # patched here; confirm that this asymmetry is intentional.
        expected_loss = self._hf_reference_loss(input_ids, patch_rms_norm=False)

        with torch.device("meta"):
            cfg = Qwen3Dense8BConfig(compile_cfg=False)
            qwen_model = cfg.build().to(torch.bfloat16)
        fsdp_config = FSDPConfig(
            tp_size=tp_size,
            cpu_offload=False,
        )

        seq_ctx, loss_ctx = self._build_model_inputs(input_ids, CELossConfig())
        qwen_model.fully_shard(fsdp_config=fsdp_config)
        qwen_model.from_hf(QWEN3_PATH)
        loss = self._forward(qwen_model, seq_ctx, loss_ctx)["loss"]
        self._assert_loss_close(loss, expected_loss, 1e-2)

    @parametrize.parametrize(
        "use_sliding_window, max_window_layers, sliding_window",
        [
            (False, 6, 1024),
            (True, 6, 1024),
            (True, 4, 2048),
        ],
    )
    def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_window):
        """Per-layer attention windows follow the config, and a forward pass runs."""
        self.create_pg('cuda')
        num_hidden_layers = 6
        with torch.device("meta"):
            attention = MHAConfig(
                num_attention_heads=32,
                num_key_value_heads=8,
                head_dim=128,
                qk_norm=True,
                sliding_window=sliding_window,
            )
            cfg = Qwen3Dense8BConfig(
                num_hidden_layers=num_hidden_layers,
                use_sliding_window=use_sliding_window,
                max_window_layers=max_window_layers,
                attention=attention,
            )
            qwen_model = cfg.build().to(torch.bfloat16)

        # Expected: the first ``max_window_layers`` layers use full attention
        # (-1, -1); the remaining layers use the sliding window, but only when
        # sliding windows are enabled.
        if use_sliding_window and max_window_layers < num_hidden_layers:
            expected_sliding_window_size_list = [(-1, -1)] * max_window_layers + [
                (sliding_window, sliding_window)
            ] * (num_hidden_layers - max_window_layers)
        else:
            expected_sliding_window_size_list = [(-1, -1)] * num_hidden_layers
        model_sliding_window_size_list = [layer.self_attn.window_size for layer in qwen_model.layers.values()]
        self.assertListEqual(model_sliding_window_size_list, expected_sliding_window_size_list)

        if not use_sliding_window:
            return

        # Smoke-test a forward pass through the sliding-window layers, reusing
        # the (still meta, unsharded) model built above.
        input_ids = self._encode_prompt(trust_remote_code=True)
        seq_ctx, loss_ctx = self._build_model_inputs(input_ids, CELossConfig())
        qwen_model.fully_shard(fsdp_config=FSDPConfig())
        # NOTE(review): strict=False presumably tolerates the checkpoint having
        # more layers than the 6 built here; confirm against from_hf.
        qwen_model.from_hf(QWEN3_PATH, strict=False)
        output = self._forward(qwen_model, seq_ctx, loss_ctx)
        self.assertIn("loss", output)

    @parametrize.parametrize(
        "device,tp_size",
        [
            ("cuda", 1),
        ],
    )
    def test_save_hf(self, device, tp_size):
        """Round-tripping ``from_hf`` -> ``save_hf`` must reproduce the HF checkpoint."""
        self.create_pg(device)
        with torch.device("meta"):
            cfg = Qwen3Dense8BConfig()
            qwen_model = cfg.build().to(torch.bfloat16)
        fsdp_config = FSDPConfig(
            tp_size=tp_size,
            cpu_offload=False,
        )
        with tempfile.TemporaryDirectory() as local_tmpdir:
            # All ranks save into rank 0's directory (presumes a shared filesystem).
            syncdir = [local_tmpdir]
            dist.broadcast_object_list(syncdir, src=0)
            tmpdir = Path(syncdir[0])
            qwen_model.fully_shard(fsdp_config=fsdp_config)
            qwen_model.from_hf(QWEN3_PATH)
            qwen_model.save_hf(tmpdir)
            # save_hf may write from several ranks; wait for all of them before reading back.
            dist.barrier()
            try:
                if dist.get_rank() == 0:
                    self._assert_hf_checkpoints_match(Path(QWEN3_PATH), tmpdir)
            finally:
                # Keep rank 0's directory alive until every rank reaches this point,
                # and do not leave other ranks hanging if rank 0's checks fail.
                dist.barrier()

    def _assert_hf_checkpoints_match(self, origin_hf_path: Path, saved_hf_path: Path) -> None:
        """Check that every tensor of the original checkpoint is saved bit-identically.

        Also check that the saved shards hold exactly the keys listed in the saved
        ``model.safetensors.index.json``.
        """
        with open(origin_hf_path / "model.safetensors.index.json", "r", encoding="utf-8") as f:
            origin_index = json.load(f)
        with open(saved_hf_path / "model.safetensors.index.json", "r", encoding="utf-8") as f:
            saved_index = json.load(f)
        saved_weight_map = saved_index["weight_map"]

        # Handles are keyed by full path: origin and saved shards can share a file
        # name, and keying by bare name would compare the origin file with itself.
        handles = {}

        def get_handle(shard_path: Path):
            shard_key = str(shard_path)
            if shard_key not in handles:
                handles[shard_key] = safe_open(shard_key, framework="pt")
            return handles[shard_key]

        # Test saved hf tensor value match the origin hf tensor value
        for key, origin_safetensor_name in origin_index["weight_map"].items():
            self.assertIn(key, saved_weight_map, msg=f"{key} missing from saved index")
            origin_tensor = get_handle(origin_hf_path / origin_safetensor_name).get_tensor(key)
            saved_tensor = get_handle(saved_hf_path / saved_weight_map[key]).get_tensor(key)
            self.assertTrue(torch.equal(origin_tensor, saved_tensor), msg=f"tensor mismatch for {key}")

        # Test the tensor number in safetensors match the tensor number in model index
        safetensor_keys = []
        for safetensor_path in saved_hf_path.glob("*.safetensors"):
            safetensor_keys.extend(get_handle(safetensor_path).keys())
        self.assertListEqual(sorted(safetensor_keys), sorted(saved_weight_map))

    @property
    def world_size(self) -> int:
        """Number of ranks to spawn; overridable via ``XTUNER_TEST_WORLD_SIZE`` (default 8)."""
        return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "8"))