-
Notifications
You must be signed in to change notification settings - Fork 388
Expand file tree
/
Copy pathtest_deepseek_v4_bridge.py
More file actions
438 lines (363 loc) · 16.4 KB
/
Copy pathtest_deepseek_v4_bridge.py
File metadata and controls
438 lines (363 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
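"""Unit tests for the DeepSeek-V4 bridge.

Locks in the MTP mapping layout of the mapping registry: per-MTP-layer HC head,
separate ``e_proj`` and ``h_proj`` mappings, and no deprecated concatenated
``eh_proj`` path. Also covers native HF config translation, quantized (FP8 /
MXFP4) weight export, the rotary-percent override, and hardware-gated fusion
defaults.
"""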
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import torch
from megatron.bridge.models.conversion import quantization_utils
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.conversion.param_mapping import AutoMapping, ReplicatedMapping
from megatron.bridge.models.deepseek.deepseek_v4_bridge import (
DeepSeekV4Bridge,
_dsv4_compress_ratios,
_dsv4_num_hash_layers,
)
@pytest.fixture
def bridge_with_mtp():
    """Provide a DSv4 bridge whose stub ``hf_config`` declares exactly one MTP layer."""
    # mapping_registry() reads only num_nextn_predict_layers from hf_config,
    # so a bare namespace carrying that one field is enough.
    stub_config = SimpleNamespace(num_nextn_predict_layers=1)
    mtp_bridge = DeepSeekV4Bridge()
    mtp_bridge.hf_config = stub_config
    return mtp_bridge
@pytest.fixture
def bridge_without_mtp():
    """Provide a DSv4 bridge whose stub ``hf_config`` declares zero MTP layers."""
    stub_config = SimpleNamespace(num_nextn_predict_layers=0)
    plain_bridge = DeepSeekV4Bridge()
    plain_bridge.hf_config = stub_config
    return plain_bridge
def _by_megatron(registry):
"""Index mappings by megatron_param for quick lookup in assertions."""
return {m.megatron_param: m for m in registry.mappings}
def _dummy_task():
return SimpleNamespace(param_name="", global_param_name="", mapping=None)
def _deepseek_v4_hf_config():
return SimpleNamespace(
head_dim=512,
qk_rope_head_dim=64,
q_lora_rank=1024,
o_groups=8,
o_lora_rank=1024,
rope_theta=10000,
compress_rope_theta=160000,
rope_scaling={"factor": 16, "original_max_position_embeddings": 65536},
num_hidden_layers=4,
num_nextn_predict_layers=1,
num_hash_layers=3,
compress_ratios=[0, 4, 128, 4, 0],
sliding_window=128,
index_n_heads=64,
index_head_dim=128,
index_topk=512,
hc_mult=4,
hc_sinkhorn_iters=20,
scoring_func="sqrtsoftplus",
num_experts_per_tok=6,
norm_topk_prob=True,
routed_scaling_factor=1.5,
vocab_size=129280,
swiglu_limit=10.0,
moe_intermediate_size=1024,
n_shared_experts=1,
tie_word_embeddings=False,
)
class TestNativeDeepSeekV4ConfigTranslation:
    """Fields of the native Transformers DSv4 config must translate back to MCore fields."""

    def test_compress_ratios_from_native_layer_types(self):
        native_layer_types = [
            "sliding_attention",
            "sliding_attention",
            "compressed_sparse_attention",
            "heavily_compressed_attention",
        ]
        native_rates = {
            "compressed_sparse_attention": 4,
            "heavily_compressed_attention": 128,
        }
        cfg = SimpleNamespace(
            num_hidden_layers=4,
            num_nextn_predict_layers=1,
            layer_types=native_layer_types,
            compress_rates=native_rates,
        )
        # Four decoder entries followed by one trailing entry for the single MTP layer.
        expected = [0, 0, 4, 128, 0]
        assert _dsv4_compress_ratios(cfg) == expected

    def test_legacy_compress_ratios_still_work(self):
        legacy_ratios = [0, 0, 4, 128, 0]
        cfg = SimpleNamespace(
            num_hidden_layers=4,
            num_nextn_predict_layers=1,
            compress_ratios=legacy_ratios,
        )
        assert _dsv4_compress_ratios(cfg) == [0, 0, 4, 128, 0]

    def test_hash_layers_from_native_mlp_layer_types(self):
        mlp_kinds = ["hash_moe"] * 3 + ["moe"] * 2
        cfg = SimpleNamespace(mlp_layer_types=mlp_kinds)
        assert _dsv4_num_hash_layers(cfg) == 3

    def test_hash_layers_must_be_prefix(self):
        # A hash_moe layer after a regular moe layer breaks the prefix requirement.
        cfg = SimpleNamespace(mlp_layer_types=["hash_moe", "moe", "hash_moe"])
        with pytest.raises(ValueError, match="contiguous prefix"):
            _dsv4_num_hash_layers(cfg)
def _mxfp4_representable_weight():
    """Return a ``(1, 32)`` bfloat16 weight whose values lie exactly on the MXFP4 E2M1 grid.

    The row holds the 16 signed E2M1 magnitudes (including ``-0.0``), repeated
    twice, so the MXFP4 export round trip is expected to be bit-exact. The MXFP4
    tests below share this helper so the code table is defined only once.
    """
    magnitudes = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
    codes = magnitudes + [-value for value in magnitudes]
    row = torch.tensor(codes, dtype=torch.float32).repeat(2)
    return row.reshape(1, 32).to(torch.bfloat16)


class TestDeepSeekV4QuantizedExport:
    """DSv4 export must regenerate quantized weights and their scale tensors."""

    # HF names used throughout: an attention projection (exported as FP8) and a
    # routed expert (exported as packed MXFP4), each paired with its scale tensor.
    FP8_PARAM = "layers.0.attn.wq_a.weight"
    FP8_SCALE = "layers.0.attn.wq_a.scale"
    MXFP4_PARAM = "layers.0.ffn.experts.0.w1.weight"
    MXFP4_SCALE = "layers.0.ffn.experts.0.w1.scale"

    def test_export_quantizes_fp8_weight_and_emits_scale(self):
        bridge = DeepSeekV4Bridge()
        hf_param, scale_key = self.FP8_PARAM, self.FP8_SCALE
        weight = torch.full((4, 4), 2.0, dtype=torch.bfloat16)
        source_state = {scale_key: torch.ones((1, 1), dtype=torch.float32)}

        result = bridge.maybe_modify_converted_hf_weight(_dummy_task(), {hf_param: weight}, source_state)

        # Quantized weight plus a scale whose shape/dtype mirror the source checkpoint's scale.
        assert set(result) == {hf_param, scale_key}
        assert result[hf_param].dtype == torch.float8_e4m3fn
        assert result[scale_key].shape == source_state[scale_key].shape
        assert result[scale_key].dtype == source_state[scale_key].dtype

        restored = bridge.maybe_modify_loaded_hf_weight(hf_param, result)
        assert restored.dtype == torch.bfloat16
        assert torch.allclose(restored.float(), weight.float())

    def test_export_preserves_e8m0_scale_dtype(self):
        e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
        if e8m0_dtype is None:
            pytest.skip("torch.float8_e8m0fnu is unavailable")
        try:
            source_scale = torch.ones((1, 1), dtype=e8m0_dtype)
        except RuntimeError as exc:
            pytest.skip(f"torch.float8_e8m0fnu tensor creation is unavailable: {exc}")

        bridge = DeepSeekV4Bridge()
        hf_param, scale_key = self.FP8_PARAM, self.FP8_SCALE
        weight = torch.full((4, 4), 2.0, dtype=torch.bfloat16)

        result = bridge.maybe_modify_converted_hf_weight(_dummy_task(), {hf_param: weight}, {scale_key: source_scale})

        # Same contract as the float32-scale case above, but the regenerated scale
        # must keep the source checkpoint's E8M0 dtype.
        assert set(result) == {hf_param, scale_key}
        assert result[hf_param].dtype == torch.float8_e4m3fn
        assert result[scale_key].dtype == e8m0_dtype
        assert result[scale_key].shape == source_scale.shape

        restored = bridge.maybe_modify_loaded_hf_weight(hf_param, result)
        assert torch.allclose(restored.float(), weight.float())

    def test_export_quantizes_routed_expert_to_mxfp4_and_emits_scale(self):
        bridge = DeepSeekV4Bridge()
        hf_param, scale_key = self.MXFP4_PARAM, self.MXFP4_SCALE
        weight = _mxfp4_representable_weight()
        source_state = {scale_key: torch.ones((1, 1), dtype=torch.float32)}

        result = bridge.maybe_modify_converted_hf_weight(_dummy_task(), {hf_param: weight}, source_state)

        assert set(result) == {hf_param, scale_key}
        # Packed int8 storage halves the last dimension (32 values -> 16 bytes).
        assert result[hf_param].dtype == torch.int8
        assert result[hf_param].shape == (1, 16)
        assert result[scale_key].shape == source_state[scale_key].shape
        assert result[scale_key].dtype == source_state[scale_key].dtype

        restored = quantization_utils.dequantize_mxfp4_e2m1_packed(result[hf_param], result[scale_key])
        assert torch.equal(restored.float(), weight.float())

    @pytest.mark.parametrize(
        "hf_param",
        [
            "layers.0.ffn.shared_experts.w1.weight",
            "layers.0.ffn.experts.0.w1.weight",
        ],
    )
    def test_export_uses_fp8_for_non_mxfp4_expert_scale_geometry(self, hf_param):
        bridge = DeepSeekV4Bridge()
        scale_key = hf_param.removesuffix(".weight") + ".scale"
        weight = torch.full((4, 4), 2.0, dtype=torch.bfloat16)

        result = bridge.maybe_modify_converted_hf_weight(
            _dummy_task(), {hf_param: weight}, {scale_key: torch.ones(1, 1)}
        )

        assert result[hf_param].dtype == torch.float8_e4m3fn
        assert result[scale_key].shape == (1, 1)

    def test_export_leaves_unscaled_weight_unchanged(self):
        bridge = DeepSeekV4Bridge()
        weight = torch.ones(4, 4, dtype=torch.bfloat16)

        result = bridge.maybe_modify_converted_hf_weight(_dummy_task(), {"norm.weight": weight}, {})

        assert set(result) == {"norm.weight"}
        assert result["norm.weight"] is weight

    def test_export_roundtrips_mixed_quantized_hf_state(self):
        bridge = DeepSeekV4Bridge()
        fp8_param, fp8_scale = self.FP8_PARAM, self.FP8_SCALE
        mxfp4_param, mxfp4_scale = self.MXFP4_PARAM, self.MXFP4_SCALE
        norm_param = "layers.0.attn_norm.weight"

        fp8_weight = torch.full((4, 4), 2.0, dtype=torch.bfloat16)
        mxfp4_weight = _mxfp4_representable_weight()
        norm_weight = torch.arange(4, dtype=torch.float32).to(torch.bfloat16)
        # Scale tensors already present in the converted state must be replaced,
        # not passed through.
        stale_scale = torch.full((1, 1), 9.0, dtype=torch.float32)

        result = bridge.maybe_modify_converted_hf_weight(
            _dummy_task(),
            {
                fp8_param: fp8_weight,
                fp8_scale: stale_scale,
                mxfp4_param: mxfp4_weight,
                mxfp4_scale: stale_scale,
                norm_param: norm_weight,
            },
            {
                fp8_scale: torch.ones((1, 1), dtype=torch.float32),
                mxfp4_scale: torch.ones((1, 1), dtype=torch.float32),
            },
        )

        assert set(result) == {fp8_param, fp8_scale, mxfp4_param, mxfp4_scale, norm_param}
        assert result[fp8_param].dtype == torch.float8_e4m3fn
        assert result[mxfp4_param].dtype == torch.int8
        assert result[mxfp4_param].shape == (1, 16)
        assert result[fp8_scale].shape == (1, 1)
        assert result[mxfp4_scale].shape == (1, 1)
        assert not torch.equal(result[fp8_scale], stale_scale)
        assert not torch.equal(result[mxfp4_scale], stale_scale)
        assert result[norm_param] is norm_weight

        restored_fp8 = bridge.maybe_modify_loaded_hf_weight(fp8_param, result)
        restored_mxfp4 = bridge.maybe_modify_loaded_hf_weight(mxfp4_param, result)
        assert torch.allclose(restored_fp8.float(), fp8_weight.float())
        assert torch.equal(restored_mxfp4.float(), mxfp4_weight.float())
class TestDecoderHCHeadMappings:
    """Each of the three global decoder HC-head parameters must use a ReplicatedMapping."""

    @pytest.mark.parametrize(
        "name",
        ["decoder.hc_head_fn", "decoder.hc_head_base", "decoder.hc_head_scale"],
    )
    def test_decoder_hc_head_replicated(self, bridge_with_mtp, name):
        lookup = _by_megatron(bridge_with_mtp.mapping_registry())
        found = lookup.get(name)
        assert found is not None, f"missing decoder HC-head mapping: {name}"
        assert isinstance(found, ReplicatedMapping)
        # The HF-side name is the Megatron name without its leading 'decoder.'.
        expected_hf_name = name.removeprefix("decoder.")
        assert found.hf_param == expected_hf_name
class TestMTPHCHeadMappings:
    """Each MTP layer carries its own HC head, mapped the same way as the decoder's."""

    @pytest.mark.parametrize(
        "suffix",
        ["hc_head_fn", "hc_head_base", "hc_head_scale"],
    )
    def test_mtp_hc_head_replicated(self, bridge_with_mtp, suffix):
        megatron_name = f"mtp.layers.0.{suffix}"
        found = _by_megatron(bridge_with_mtp.mapping_registry()).get(megatron_name)
        assert found is not None, f"missing MTP HC-head mapping: {megatron_name}"
        assert isinstance(found, ReplicatedMapping)
        assert found.hf_param == f"mtp.0.{suffix}"

    def test_mtp_hc_head_absent_when_no_mtp(self, bridge_without_mtp):
        known = _by_megatron(bridge_without_mtp.mapping_registry())
        leaked = [
            f"mtp.layers.0.{suffix}"
            for suffix in ("hc_head_fn", "hc_head_base", "hc_head_scale")
            if f"mtp.layers.0.{suffix}" in known
        ]
        assert leaked == []
class TestMTPEHProjSplit:
    """MTP e_proj and h_proj are separate ColumnParallelLinear modules.

    The bridge must use two AutoMappings (which auto-detect column parallelism),
    not the deprecated concatenated eh_proj path.
    """

    @staticmethod
    def _hf_names(hf_param):
        """Return every HF parameter name referenced by a mapping's ``hf_param``.

        ``hf_param`` may be a single name, a dict whose values are names, or a
        list/tuple of names. Normalizing here means the deprecated-name check
        cannot silently skip a list/tuple-shaped ``hf_param``. Any other value
        yields no names.
        """
        if isinstance(hf_param, str):
            return [hf_param]
        if isinstance(hf_param, dict):
            return list(hf_param.values())
        if isinstance(hf_param, (list, tuple)):
            return list(hf_param)
        return []

    @pytest.mark.parametrize("name", ["e_proj", "h_proj"])
    def test_split_proj_automapping(self, bridge_with_mtp, name):
        registry = bridge_with_mtp.mapping_registry()
        mapping = _by_megatron(registry).get(f"mtp.layers.0.{name}.weight")
        assert mapping is not None, f"missing MTP projection: {name}"
        assert isinstance(mapping, AutoMapping)
        assert mapping.hf_param == f"mtp.0.{name}.weight"

    def test_eh_proj_not_in_registry(self, bridge_with_mtp):
        registry = bridge_with_mtp.mapping_registry()
        for mapping in registry.mappings:
            assert "eh_proj" not in mapping.megatron_param, (
                f"deprecated eh_proj reference found in megatron_param: {mapping.megatron_param}"
            )
            for hf_name in self._hf_names(mapping.hf_param):
                assert "eh_proj" not in hf_name, f"deprecated eh_proj reference found in hf_param: {hf_name}"
class TestDeepSeekV4RotaryPercent:
    """Regression test for the rotary-percent override.

    HF's ``partial_rotary_factor`` is expressed relative to ``head_dim=512``. It must
    not shrink the Megatron rope cache, because ``qk_pos_emb_head_dim`` (64) already
    encodes the rope split. Leaving ``rotary_percent=0.125`` would yield an 8-dim
    cos/sin cache. The unfused path would then silently rotate only 8 of 64 dims,
    and the fused MLA rope kernel would read cos/sin out of bounds (SFT NaN).
    """

    def test_provider_bridge_forces_full_rotary_percent(self):
        pretrained = MagicMock()
        pretrained.config = _deepseek_v4_hf_config()
        base_provider = MagicMock()
        # Value the generic partial_rotary_factor -> rotary_percent mapping produces (64 / 512).
        base_provider.rotary_percent = 0.125
        # Bypass __init__; only provider_bridge is exercised.
        bridge = DeepSeekV4Bridge.__new__(DeepSeekV4Bridge)
        with patch.object(MegatronModelBridge, "provider_bridge", return_value=base_provider):
            translated = bridge.provider_bridge(pretrained)
        assert translated.rotary_percent == 1.0
class TestDeepSeekV4HardwareDefaults:
    """Fused kernels that DSv4 supports only on Blackwell must not default on for Hopper."""

    @pytest.mark.parametrize(
        ("capability", "expected"),
        [
            ((9, 0), False),
            ((10, 0), True),
        ],
    )
    def test_provider_bridge_gates_blackwell_only_fusions(self, capability, expected):
        pretrained = MagicMock()
        pretrained.config = _deepseek_v4_hf_config()
        base_provider = MagicMock()
        bridge = DeepSeekV4Bridge.__new__(DeepSeekV4Bridge)
        # Pretend a CUDA device with the parametrized compute capability is present.
        with (
            patch.object(MegatronModelBridge, "provider_bridge", return_value=base_provider),
            patch.object(torch.cuda, "is_available", return_value=True),
            patch.object(torch.cuda, "get_device_capability", return_value=capability),
        ):
            translated = bridge.provider_bridge(pretrained)
        assert translated.apply_dsa_kernel_fusion is expected
        assert translated.use_fused_mhc is expected

    def test_provider_bridge_preserves_fused_defaults_without_cuda(self):
        pretrained = MagicMock()
        pretrained.config = _deepseek_v4_hf_config()
        base_provider = MagicMock()
        bridge = DeepSeekV4Bridge.__new__(DeepSeekV4Bridge)
        # With no CUDA device visible, both fusion flags are expected to read True.
        with (
            patch.object(MegatronModelBridge, "provider_bridge", return_value=base_provider),
            patch.object(torch.cuda, "is_available", return_value=False),
        ):
            translated = bridge.provider_bridge(pretrained)
        assert translated.apply_dsa_kernel_fusion is True
        assert translated.use_fused_mhc is True