Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/megatron/bridge/models/deepseek/deepseek_v4_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@
}


def deepseek_v4_supports_blackwell_fused_kernels() -> bool:
    """Decide whether the Blackwell-only DSv4 fused kernels default to enabled.

    Returns:
        ``True`` when CUDA is not available (the fused defaults are kept as-is
        in that case), otherwise ``True`` only if the major compute capability
        reported by ``torch.cuda.get_device_capability()`` is at least 10.
    """
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        # Only the major version matters for the gate; the minor is ignored.
        return capability[0] >= 10

    # No CUDA device to inspect: leave the fused kernels enabled by default.
    return True


def set_deepseek_v4_pipeline_model_parallel_layout(model_cfg: MLAModelProvider) -> None:
"""Set an even DSv4 pipeline layout with MTP and loss on the last stage.

Expand Down Expand Up @@ -370,6 +379,7 @@ def generate_pipeline_layout(num_layers: int, pp: int, mtp_layers: int = 1) -> l
def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MLAModelProvider:
provider = super().provider_bridge(hf_pretrained)
hf_config = hf_pretrained.config
use_blackwell_fused_kernels = deepseek_v4_supports_blackwell_fused_kernels()

# ---- Attention ----
provider.experimental_attention_variant = "dsv4_hybrid"
Expand Down Expand Up @@ -447,10 +457,11 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MLAModelProvider
provider.dsa_indexer_n_heads = hf_config.index_n_heads # 64
provider.dsa_indexer_head_dim = hf_config.index_head_dim # 128
provider.dsa_indexer_topk = hf_config.index_topk # 512
provider.apply_dsa_kernel_fusion = use_blackwell_fused_kernels

# ---- Hyper-Connections (mHC) ----
provider.enable_hyper_connections = True
provider.use_fused_mhc = True
provider.use_fused_mhc = use_blackwell_fused_kernels
provider.num_residual_streams = hf_config.hc_mult # 4
provider.mhc_sinkhorn_iterations = hf_config.hc_sinkhorn_iters # 20

Expand Down
45 changes: 24 additions & 21 deletions src/megatron/bridge/recipes/deepseek/deepseek_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from megatron.core.quantization.quant_config import RecipeConfig

from megatron.bridge import AutoBridge
from megatron.bridge.models.deepseek.deepseek_v4_bridge import set_deepseek_v4_pipeline_model_parallel_layout
from megatron.bridge.models.deepseek.deepseek_v4_bridge import (
deepseek_v4_supports_blackwell_fused_kernels,
set_deepseek_v4_pipeline_model_parallel_layout,
)
from megatron.bridge.recipes.common import _pretrain_common, _sft_common
from megatron.bridge.recipes.utils.finetune_utils import default_squad_config
from megatron.bridge.recipes.utils.optimizer_utils import (
Expand Down Expand Up @@ -56,7 +59,7 @@ def deepseek_v4_flash_pretrain_config() -> ConfigContainer:

Recommended Blackwell baseline: TP=1, PP=4, EP=8, CP=1.
"""
use_fused_kernels = True
use_fused_mhc = deepseek_v4_supports_blackwell_fused_kernels()
cfg = _pretrain_common()
cfg.model = AutoBridge.from_hf_pretrained(
"deepseek-ai/DeepSeek-V4-Flash", trust_remote_code=True
Expand All @@ -82,8 +85,8 @@ def deepseek_v4_flash_pretrain_config() -> ConfigContainer:
cfg.model.transformer_impl = "transformer_engine"
cfg.model.attention_backend = None
cfg.model.apply_dsa_kernel_fusion = False
cfg.model.apply_rope_fusion = use_fused_kernels
cfg.model.use_fused_mhc = use_fused_kernels
cfg.model.apply_rope_fusion = True
cfg.model.use_fused_mhc = use_fused_mhc
cfg.model.dsa_indexer_loss_coeff = 0.0
cfg.model.dsa_indexer_use_sparse_loss = False

Expand Down Expand Up @@ -156,10 +159,10 @@ def deepseek_v4_flash_pretrain_mxfp8_config() -> ConfigContainer:
cfg.train.train_iters = 1_000_000
cfg.train.global_batch_size = 128
cfg.train.micro_batch_size = 1
use_fused_kernels = True
use_fused_mhc = deepseek_v4_supports_blackwell_fused_kernels()
cfg.model.apply_dsa_kernel_fusion = False
cfg.model.apply_rope_fusion = use_fused_kernels
cfg.model.use_fused_mhc = use_fused_kernels
cfg.model.apply_rope_fusion = True
cfg.model.use_fused_mhc = use_fused_mhc
cfg.model.dsa_indexer_loss_coeff = 0.0
cfg.model.dsa_indexer_use_sparse_loss = False
cfg.model.moe_token_dispatcher_type = "alltoall"
Expand Down Expand Up @@ -224,10 +227,10 @@ def deepseek_v4_flash_pretrain_muon_config() -> ConfigContainer:
cfg.train.train_iters = 1_000_000
cfg.train.global_batch_size = 128
cfg.train.micro_batch_size = 1
use_fused_kernels = True
use_fused_mhc = deepseek_v4_supports_blackwell_fused_kernels()
cfg.model.apply_dsa_kernel_fusion = False
cfg.model.apply_rope_fusion = use_fused_kernels
cfg.model.use_fused_mhc = use_fused_kernels
cfg.model.apply_rope_fusion = True
cfg.model.use_fused_mhc = use_fused_mhc
cfg.model.dsa_indexer_loss_coeff = 0.0
cfg.model.dsa_indexer_use_sparse_loss = False
cfg.model.moe_token_dispatcher_type = "alltoall"
Expand Down Expand Up @@ -282,10 +285,11 @@ def deepseek_v4_flash_pretrain_muon_config() -> ConfigContainer:


def deepseek_v4_flash_sft_config(hf_path: str = DEEPSEEK_V4_FLASH_HF_PATH) -> ConfigContainer:
"""DeepSeek-V4-Flash full SFT, MTP enabled, Hopper-safe (unfused mHC, bf16).
"""DeepSeek-V4-Flash full SFT, MTP enabled, Hopper-safe.

Runs unchanged on Hopper (H100/H200) and Blackwell (B200/GB200). Full
parameter training on unpacked (SBHD) sequences with Adam/bf16. Set
Runs unchanged on Hopper (H100/H200) and Blackwell (B200/GB200). Fused mHC
is enabled only on Blackwell. Full parameter training on unpacked (SBHD)
sequences with Adam/bf16. Set
``checkpoint.pretrained_checkpoint`` to the imported Megatron checkpoint to
fine-tune real weights; ``hf_path`` overrides the HF model id (e.g. a toy
model in tests).
Expand All @@ -305,12 +309,12 @@ def deepseek_v4_flash_sft_config(hf_path: str = DEEPSEEK_V4_FLASH_HF_PATH) -> Co
cfg.model.params_dtype = torch.bfloat16
cfg.model.seq_length = 4096

# --- attention / kernels: fused mHC + fused rope (Blackwell-verified), unfused DSA ---
# --- attention / kernels: fused mHC on Blackwell, unfused mHC on Hopper, unfused DSA ---
cfg.model.transformer_impl = "transformer_engine"
cfg.model.attention_backend = None
cfg.model.apply_dsa_kernel_fusion = False
cfg.model.apply_rope_fusion = True
cfg.model.use_fused_mhc = True
cfg.model.use_fused_mhc = deepseek_v4_supports_blackwell_fused_kernels()
cfg.model.dsa_indexer_loss_coeff = 0.0
cfg.model.dsa_indexer_use_sparse_loss = False

Expand Down Expand Up @@ -352,7 +356,7 @@ def deepseek_v4_flash_no_mtp_sft_config(hf_path: str = DEEPSEEK_V4_FLASH_HF_PATH
"""DeepSeek-V4-Flash full SFT with the MTP layer disabled, Hopper-safe.

Same as :func:`deepseek_v4_flash_sft_config` but drops the Multi-Token
Prediction layer (unfused mHC, bf16, SBHD; runs on Hopper and Blackwell).
Prediction layer (fused mHC only on Blackwell, bf16, SBHD).
"""
cfg = _sft_common()
cfg.model = AutoBridge.from_hf_pretrained(hf_path, trust_remote_code=True).to_megatron_provider(load_weights=False)
Expand All @@ -369,12 +373,12 @@ def deepseek_v4_flash_no_mtp_sft_config(hf_path: str = DEEPSEEK_V4_FLASH_HF_PATH
cfg.model.params_dtype = torch.bfloat16
cfg.model.seq_length = 4096

# --- attention / kernels: fused mHC + fused rope (Blackwell-verified), unfused DSA ---
# --- attention / kernels: fused mHC on Blackwell, unfused mHC on Hopper, unfused DSA ---
cfg.model.transformer_impl = "transformer_engine"
cfg.model.attention_backend = None
cfg.model.apply_dsa_kernel_fusion = False
cfg.model.apply_rope_fusion = True
cfg.model.use_fused_mhc = True
cfg.model.use_fused_mhc = deepseek_v4_supports_blackwell_fused_kernels()
cfg.model.dsa_indexer_loss_coeff = 0.0
cfg.model.dsa_indexer_use_sparse_loss = False

Expand Down Expand Up @@ -419,11 +423,10 @@ def deepseek_v4_flash_no_mtp_sft_config(hf_path: str = DEEPSEEK_V4_FLASH_HF_PATH
return cfg


# NOTE: the SFT recipes enable fused mHC and fused rope, matching the pretrain recipes.
# NOTE: the SFT recipes enable fused mHC on Blackwell and fused rope on all supported GPUs.
# The historical "fused-kernel SFT NaN" reports are both resolved: fused mHC was a confound,
# and the fused-rope NaN was a bridge config-mapping bug fixed by rotary_percent=1.0 (#4271);
# with that fix, full-model SFT with rope fusion matches the unfused control. The fused mHC
# cuTile kernel is sm_100 (Blackwell); on Hopper set use_fused_mhc=False.
# with that fix, full-model SFT with rope fusion matches the unfused control.
#
# NOTE: there are intentionally no MXFP8 or Muon *SFT* variants either. Both were prototyped
# (mirroring the pretrain recipes) but fail in full-model DSv4-Flash SFT — MXFP8 NaNs at iter-2
Expand Down
112 changes: 77 additions & 35 deletions tests/unit_tests/models/deepseek/test_deepseek_v4_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
"""

from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
import torch

from megatron.bridge.models.conversion import quantization_utils
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.conversion.param_mapping import AutoMapping, ReplicatedMapping
from megatron.bridge.models.deepseek.deepseek_v4_bridge import (
DeepSeekV4Bridge,
Expand Down Expand Up @@ -58,6 +60,38 @@ def _dummy_task():
return SimpleNamespace(param_name="", global_param_name="", mapping=None)


def _deepseek_v4_hf_config():
    """Build a stand-in HF config object carrying DSv4 fields for bridge tests.

    A fresh ``SimpleNamespace`` (with fresh list/dict values) is created on
    every call, so tests may mutate the result without affecting each other.
    """
    fields = {
        # attention / low-rank projections
        "head_dim": 512,
        "qk_rope_head_dim": 64,
        "q_lora_rank": 1024,
        "o_groups": 8,
        "o_lora_rank": 1024,
        # rotary embedding settings
        "rope_theta": 10000,
        "compress_rope_theta": 160000,
        "rope_scaling": {"factor": 16, "original_max_position_embeddings": 65536},
        # layer counts
        "num_hidden_layers": 4,
        "num_nextn_predict_layers": 1,
        "num_hash_layers": 3,
        "compress_ratios": [0, 4, 128, 4, 0],
        "sliding_window": 128,
        # indexer
        "index_n_heads": 64,
        "index_head_dim": 128,
        "index_topk": 512,
        # hyper-connections
        "hc_mult": 4,
        "hc_sinkhorn_iters": 20,
        # MoE routing
        "scoring_func": "sqrtsoftplus",
        "num_experts_per_tok": 6,
        "norm_topk_prob": True,
        "routed_scaling_factor": 1.5,
        "vocab_size": 129280,
        "swiglu_limit": 10.0,
        "moe_intermediate_size": 1024,
        "n_shared_experts": 1,
        "tie_word_embeddings": False,
    }
    return SimpleNamespace(**fields)


class TestNativeDeepSeekV4ConfigTranslation:
"""Native Transformers DSv4 config fields must map back to MCore fields."""

Expand Down Expand Up @@ -349,42 +383,8 @@ class TestDeepSeekV4RotaryPercent:
rotates 8/64 dims and the fused MLA rope kernel reads cos/sin out of bounds (SFT NaN)."""

def test_provider_bridge_forces_full_rotary_percent(self):
from unittest.mock import MagicMock, patch

from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.deepseek.deepseek_v4_bridge import DeepSeekV4Bridge

hf_config = SimpleNamespace(
head_dim=512,
qk_rope_head_dim=64,
q_lora_rank=1024,
o_groups=8,
o_lora_rank=1024,
rope_theta=10000,
compress_rope_theta=160000,
rope_scaling={"factor": 16, "original_max_position_embeddings": 65536},
num_hidden_layers=4,
num_nextn_predict_layers=1,
num_hash_layers=3,
compress_ratios=[0, 4, 128, 4, 0],
sliding_window=128,
index_n_heads=64,
index_head_dim=128,
index_topk=512,
hc_mult=4,
hc_sinkhorn_iters=20,
scoring_func="sqrtsoftplus",
num_experts_per_tok=6,
norm_topk_prob=True,
routed_scaling_factor=1.5,
vocab_size=129280,
swiglu_limit=10.0,
moe_intermediate_size=1024,
n_shared_experts=1,
tie_word_embeddings=False,
)
hf_pretrained = MagicMock()
hf_pretrained.config = hf_config
hf_pretrained.config = _deepseek_v4_hf_config()
provider = MagicMock()
# what the generic partial_rotary_factor -> rotary_percent mapping produces
provider.rotary_percent = 0.125
Expand All @@ -394,3 +394,45 @@ def test_provider_bridge_forces_full_rotary_percent(self):
out = bridge.provider_bridge(hf_pretrained)

assert out.rotary_percent == 1.0


class TestDeepSeekV4HardwareDefaults:
    """DSv4 Blackwell-only fused kernels must not default on for Hopper."""

    @staticmethod
    def _bridge_and_checkpoint():
        """Return an uninitialised bridge plus a mocked HF checkpoint with a DSv4 config."""
        checkpoint = MagicMock()
        checkpoint.config = _deepseek_v4_hf_config()
        # __new__ skips __init__, so no real bridge setup is performed.
        return DeepSeekV4Bridge.__new__(DeepSeekV4Bridge), checkpoint

    @pytest.mark.parametrize(
        ("capability", "expected"),
        [
            ((9, 0), False),
            ((10, 0), True),
        ],
    )
    def test_provider_bridge_gates_blackwell_only_fusions(self, capability, expected):
        """With CUDA present, both fusions follow the reported compute capability."""
        bridge, checkpoint = self._bridge_and_checkpoint()
        base_provider = MagicMock()

        with patch.object(MegatronModelBridge, "provider_bridge", return_value=base_provider):
            with patch.object(torch.cuda, "is_available", return_value=True):
                with patch.object(torch.cuda, "get_device_capability", return_value=capability):
                    result = bridge.provider_bridge(checkpoint)

        assert result.apply_dsa_kernel_fusion is expected
        assert result.use_fused_mhc is expected

    def test_provider_bridge_preserves_fused_defaults_without_cuda(self):
        """Without CUDA, both fusions stay enabled."""
        bridge, checkpoint = self._bridge_and_checkpoint()
        base_provider = MagicMock()

        with patch.object(MegatronModelBridge, "provider_bridge", return_value=base_provider):
            with patch.object(torch.cuda, "is_available", return_value=False):
                result = bridge.provider_bridge(checkpoint)

        assert result.apply_dsa_kernel_fusion is True
        assert result.use_fused_mhc is True
25 changes: 25 additions & 0 deletions tests/unit_tests/recipes/test_deepseek_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def test_deepseek_v3_pipeline_layout_keeps_default_mtp_with_loss():
def _build_deepseek_v4_recipe(name: str, monkeypatch: pytest.MonkeyPatch):
    """Build the named DSv4 recipe with the bridge faked and the hardware check forced to True.

    Args:
        name: Attribute name of the recipe factory in the DSv4 recipes module.
        monkeypatch: pytest fixture used to install the stubs.

    Returns:
        Whatever the recipe factory returns when called with no arguments.
    """
    recipes = importlib.import_module("megatron.bridge.recipes.deepseek.deepseek_v4")
    stubs = {
        "AutoBridge": _FakeBridge,
        # Force the "fused kernels supported" answer so results do not depend on the host GPU.
        "deepseek_v4_supports_blackwell_fused_kernels": lambda: True,
    }
    for attr_name, replacement in stubs.items():
        monkeypatch.setattr(recipes, attr_name, replacement)

    recipe_factory = getattr(recipes, name)
    return recipe_factory()


Expand Down Expand Up @@ -391,6 +392,30 @@ def test_deepseek_v4_base_recipe_uses_blackwell_defaults(monkeypatch: pytest.Mon
assert cfg.train.micro_batch_size == 1


@pytest.mark.parametrize(
    "recipe_name",
    [
        "deepseek_v4_flash_pretrain_config",
        "deepseek_v4_flash_pretrain_mxfp8_config",
        "deepseek_v4_flash_pretrain_muon_config",
        "deepseek_v4_flash_sft_config",
        "deepseek_v4_flash_no_mtp_sft_config",
    ],
)
def test_deepseek_v4_recipes_disable_blackwell_only_fusions_when_unavailable(
    recipe_name: str, monkeypatch: pytest.MonkeyPatch
):
    """When the hardware check reports False, each recipe turns fused mHC off.

    DSA kernel fusion is expected to stay off and rope fusion to stay on.
    """
    recipes = importlib.import_module("megatron.bridge.recipes.deepseek.deepseek_v4")
    monkeypatch.setattr(recipes, "AutoBridge", _FakeBridge)
    # Simulate a host where the Blackwell-only fused kernels are unavailable.
    monkeypatch.setattr(recipes, "deepseek_v4_supports_blackwell_fused_kernels", lambda: False)

    build_recipe = getattr(recipes, recipe_name)
    model_cfg = build_recipe().model

    assert model_cfg.apply_dsa_kernel_fusion is False
    assert model_cfg.apply_rope_fusion is True
    assert model_cfg.use_fused_mhc is False


def test_deepseek_v4_flash_sft_recipe_uses_fused_mhc(monkeypatch: pytest.MonkeyPatch):
cfg = _build_deepseek_v4_recipe("deepseek_v4_flash_sft_config", monkeypatch)

Expand Down
Loading