
Commit 82ea671

zorrofox and Test User authored
feat(models): MiMoV2MTPForCausalLM (V2.5-Pro NextN/MTP draft model) (#1065)
Part of #1053 (P1-3). Single-layer draft model for the MiMo-V2.5-Pro 3-layer MTP; one instance loads exactly one layer from model_mtp.safetensors (selected via config.mtp_layer_idx). MultiLayerDraftWorker (#1053 P1-4) will create one instance per layer.

Architecture (mirrors sglang GPU mimo_v2_nextn.py):
- MiMoV2NextNDecoderLayer: SWA MiMoV2Attention + dense MiMoV2MLP (the MTP block is NOT MoE, unlike the target's 384-expert layers)
- MiMoV2ModelNextN: embed + enorm/hnorm + eh_proj (2h→h) + 1 block + final_ln
- MiMoV2MTPForCausalLM: load_lm_head_from_target=True; load_weights handles the V2.5-Pro fused-QKV per-shard FP8 layout via loader.dequant_fused_qkv + dequant_fp8_layers for the dense MLP

Weight mapping verified against the live model_mtp.safetensors header (48 tensors = 16 per layer × 3): test_mimo_v2_nextn.py asserts exact key coverage, per-layer-idx selection, and unique target paths (3/3 passed). E2E load+forward numerical verification is deferred to P1-4 (the MultiLayerDraftWorker E2E acceptance criterion "bs=1 accept-len ≥2.5" implies per-layer logits correctness).

Co-authored-by: Test User <test@example.com>
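A minimal sketch (not part of this commit) of how the planned MultiLayerDraftWorker (#1053 P1-4) might fan out one MiMoV2MTPForCausalLM instance per MTP layer by cloning the config and varying mtp_layer_idx. The helper name and its arguments are assumptions for illustration, and the import of MiMoV2MTPForCausalLM is omitted because the new module's path is not shown in this excerpt; only the class itself, config.mtp_layer_idx, and load_weights come from this commit.

import copy

def build_mtp_draft_models(base_hf_config, mesh, model_config, num_mtp_layers=3):
    """Hypothetical helper: one draft instance per MTP layer."""
    drafts = []
    for idx in range(num_mtp_layers):
        cfg = copy.deepcopy(base_hf_config)
        cfg.mtp_layer_idx = idx  # selects which model.mtp.layers.{idx}.* tensors this instance loads
        draft = MiMoV2MTPForCausalLM(cfg, mesh=mesh)  # class defined in the new file below
        draft.load_weights(model_config)
        drafts.append(draft)
    return drafts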
1 parent 8c78596 commit 82ea671

4 files changed

Lines changed: 397 additions & 0 deletions

File tree

python/sgl_jax/srt/configs/model_config.py

Lines changed: 3 additions & 0 deletions
@@ -141,6 +141,9 @@ def __init__(
 
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTPForCausalLM"
+
+        if is_draft_model and self.hf_config.architectures[0] == "MiMoV2ForCausalLM":
+            self.hf_config.architectures[0] = "MiMoV2MTPForCausalLM"
         # Check model type
         self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
         self.is_multimodal = False
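For illustration only, a standalone reproduction of the override above: the draft model reuses the target's HF config, whose architectures[0] is "MiMoV2ForCausalLM", and the draft-model path rewrites it so model resolution presumably picks the MTP draft class (registered via EntryClass in the new file below) instead of the full target. The SimpleNamespace stand-in is an assumption; the real check runs inside ModelConfig.__init__ against a transformers PretrainedConfig.

from types import SimpleNamespace

# Stand-in for the target's hf_config (the real object is a transformers config).
hf_config = SimpleNamespace(architectures=["MiMoV2ForCausalLM"])
is_draft_model = True

if is_draft_model and hf_config.architectures[0] == "MiMoV2ForCausalLM":
    hf_config.architectures[0] = "MiMoV2MTPForCausalLM"

assert hf_config.architectures[0] == "MiMoV2MTPForCausalLM"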
Lines changed: 330 additions & 0 deletions
@@ -0,0 +1,330 @@
"""MiMo-V2.5-Pro multi-token-prediction (NextN/MTP) draft model.

One instance loads exactly one of the 3 MTP layers from
``model_mtp.safetensors`` (selected via ``config.mtp_layer_idx``); the
``MultiLayerDraftWorker`` (#1053 P1-4) creates one instance per layer. The
MTP block is a SWA-attention + dense-MLP decoder (not MoE), with the same
fused-QKV per-shard FP8 layout as the V2.5-Pro target.
"""

import logging
from types import SimpleNamespace

import jax
import jax.numpy as jnp
from flax import nnx
from transformers import PretrainedConfig

from sgl_jax.srt.configs.model_config import ModelConfig
from sgl_jax.srt.layers.embeddings import Embed, ParallelLMHead
from sgl_jax.srt.layers.layernorm import RMSNorm
from sgl_jax.srt.layers.linear import LinearBase
from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor
from sgl_jax.srt.mem_cache.memory_pool import KVCache, MemoryPools
from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch
from sgl_jax.srt.models.mimo_v2_flash import MiMoV2Attention, MiMoV2MLP
from sgl_jax.srt.utils.weight_utils import WeightLoader, WeightMapping

logger = logging.getLogger(__name__)


class MiMoV2NextNDecoderLayer(nnx.Module):
    """Single MTP decoder block: SWA attention + dense MLP (never MoE)."""

    def __init__(
        self,
        config: PretrainedConfig,
        mesh: jax.sharding.Mesh,
        layer_id: int = 0,
        dtype: jnp.dtype = jnp.bfloat16,
    ):
        self.layer_id = layer_id
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "default":
            rope_scaling = None
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)

        self.self_attn = MiMoV2Attention(
            hidden_size=config.hidden_size,
            num_heads=config.swa_num_attention_heads,
            num_kv_heads=config.swa_num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            rope_theta=getattr(config, "swa_rope_theta", rope_theta),
            rope_scaling=rope_scaling,
            head_dim=config.swa_head_dim,
            v_head_dim=getattr(config, "swa_v_head_dim", None),
            sliding_window_size=getattr(config, "sliding_window_size", None),
            attention_sink_bias=getattr(config, "add_swa_attention_sink_bias", False),
            partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
            attention_value_scale=getattr(config, "attention_value_scale", None),
            layer_id=layer_id,
            dtype=dtype,
            mesh=mesh,
        )
        self.mlp = MiMoV2MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_id=layer_id,
            dtype=dtype,
            mesh=mesh,
        )
        self.input_layernorm = RMSNorm(
            config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype
        )

    def __call__(
        self,
        positions: jax.Array,
        hidden_states: jax.Array,
        forward_batch: ForwardBatch,
        token_to_kv_pool: KVCache,
        residual: jax.Array | None = None,
    ) -> tuple[jax.Array, jax.Array, jax.Array, None]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states, kv_fused = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            token_to_kv_pool=token_to_kv_pool,
        )

        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual, kv_fused, None


class MiMoV2ModelNextN(nnx.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        mesh: jax.sharding.Mesh,
        dtype: jnp.dtype = jnp.bfloat16,
    ):
        self.embed_tokens = Embed(
            num_embeddings=config.vocab_size,
            features=config.hidden_size,
            dtype=dtype,
            kernel_axes=("tensor", None),
            param_dtype=dtype,
            mesh=mesh,
        )
        self.enorm = RMSNorm(
            config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype
        )
        self.hnorm = RMSNorm(
            config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype
        )
        self.eh_proj = LinearBase(
            input_size=2 * config.hidden_size,
            output_size=config.hidden_size,
            use_bias=False,
            kernel_axes=(None, None),
            params_dtype=dtype,
            mesh=mesh,
        )
        self.mtp_block = MiMoV2NextNDecoderLayer(config, mesh=mesh, layer_id=0, dtype=dtype)
        self.final_layernorm = RMSNorm(
            config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype
        )

    def __call__(
        self, forward_batch: ForwardBatch, token_to_kv_pool: KVCache
    ) -> tuple[jax.Array, list[jax.Array]]:
        embed = self.embed_tokens(forward_batch.input_ids)
        hidden_in = forward_batch.spec_info.hidden_states
        hidden_states, _ = self.eh_proj(
            jnp.concatenate((self.enorm(embed), self.hnorm(hidden_in)), axis=-1)
        )
        hidden_states, residual, kv_fused, _ = self.mtp_block(
            forward_batch.positions, hidden_states, forward_batch, token_to_kv_pool, None
        )
        hidden_states = self.final_layernorm(hidden_states + residual)
        return hidden_states, [kv_fused]


class MiMoV2MTPForCausalLM(nnx.Module):

    load_lm_head_from_target = True

    def __init__(
        self,
        config: PretrainedConfig,
        mesh: jax.sharding.Mesh | None = None,
        dtype: jnp.dtype = jnp.bfloat16,
    ):
        self.config = config
        self.mesh = mesh
        self.dtype = dtype
        self.mtp_layer_idx = getattr(config, "mtp_layer_idx", 0)
        self.model = MiMoV2ModelNextN(config, mesh=mesh, dtype=dtype)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            dtype=dtype,
            param_dtype=dtype,
            kernel_axes=("tensor", None),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size, mesh=self.mesh)
        self._fused_qkv_buffers: dict[int, dict] = {}
        self.hot_token_ids = None

    def __call__(
        self,
        forward_batch: ForwardBatch,
        memory_pools: MemoryPools,
        logits_metadata: LogitsMetadata,
    ):
        hidden_states, layers_kv_fused = self.model(forward_batch, memory_pools.token_to_kv_pool)
        output = self.logits_processor(
            hidden_states, self.lm_head, logits_metadata, aux_hidden_states=None
        )
        return output, layers_kv_fused, []

    def load_weights(self, model_config: ModelConfig):
        self.loader = WeightLoader(
            model=self, model_config=model_config, mesh=self.mesh, dtype=self.dtype
        )
        mappings = self._create_weight_mappings()
        self.loader.load_weights_from_safetensors(mappings)

        if self.loader.is_static_quant:
            attn = self.model.mtp_block.self_attn
            head_dim, v_head_dim = attn.head_dim, attn.v_head_dim
            # dequant_fused_qkv reads full-attn config fields; the MTP block uses
            # SWA dims, so derive the split config from the actual layer.
            mtp_qkv_config = SimpleNamespace(
                head_dim=head_dim,
                v_head_dim=v_head_dim,
                num_attention_heads=attn.q_head_num,
                num_key_value_heads=attn.k_head_num,
            )
            self.loader.dequant_fused_qkv(
                self._fused_qkv_buffers, [self.model.mtp_block], mtp_qkv_config
            )
            self.loader.dequant_fp8_layers(
                [self.model.mtp_block],
                specs=[
                    ("mlp.gate_proj", None),
                    ("mlp.up_proj", None),
                    ("mlp.down_proj", None),
                ],
            )
            self.loader.replicate_kv_heads(
                [self.model.mtp_block],
                specs=[("self_attn.k_proj", head_dim), ("self_attn.v_proj", v_head_dim)],
                target_kv_heads_fn=lambda attn: attn.k_head_num,
            )
        logger.info(
            "MiMoV2 MTP layer %d weights loaded (fused-qkv FP8=%s)",
            self.mtp_layer_idx,
            self.loader.is_static_quant,
        )

    def _create_weight_mappings(self) -> dict[str, WeightMapping]:
        idx = self.mtp_layer_idx
        prefix = f"model.mtp.layers.{idx}"
        block = "model.mtp_block"
        is_fp8 = self.loader.is_static_quant

        mappings: dict[str, WeightMapping] = {
            f"{prefix}.enorm.weight": WeightMapping(
                target_path="model.enorm.scale", sharding=(None,), transpose=False
            ),
            f"{prefix}.hnorm.weight": WeightMapping(
                target_path="model.hnorm.scale", sharding=(None,), transpose=False
            ),
            f"{prefix}.eh_proj.weight": WeightMapping(
                target_path="model.eh_proj.weight", sharding=(None, None), transpose=True
            ),
            f"{prefix}.final_layernorm.weight": WeightMapping(
                target_path="model.final_layernorm.scale", sharding=(None,), transpose=False
            ),
            f"{prefix}.input_layernorm.weight": WeightMapping(
                target_path=f"{block}.input_layernorm.scale", sharding=(None,), transpose=False
            ),
            f"{prefix}.pre_mlp_layernorm.weight": WeightMapping(
                target_path=f"{block}.post_attention_layernorm.scale",
                sharding=(None,),
                transpose=False,
            ),
            f"{prefix}.self_attn.o_proj.weight": WeightMapping(
                target_path=f"{block}.self_attn.o_proj.weight",
                sharding=("tensor", None),
                transpose=True,
                head_dim_padding=True,
            ),
            f"{prefix}.self_attn.attention_sink_bias": WeightMapping(
                target_path=f"{block}.self_attn.attention_sink_bias",
                sharding=("tensor",),
                transpose=False,
            ),
        }

        qkv_key = f"{prefix}.self_attn.qkv_proj"
        if is_fp8 and not self.loader.is_quant_ignored(qkv_key):
            mappings[f"{qkv_key}.weight"] = WeightMapping(
                target_path="__FUSED_QKV_WEIGHT__0", sharding=(None, None), transpose=False
            )
            mappings[f"{qkv_key}.weight_scale_inv"] = WeightMapping(
                target_path="__FUSED_QKV_SCALE__0", sharding=(None, None), transpose=False
            )
        else:
            mappings[f"{qkv_key}.weight"] = WeightMapping(
                target_path=[
                    f"{block}.self_attn.q_proj.weight",
                    f"{block}.self_attn.k_proj.weight",
                    f"{block}.self_attn.v_proj.weight",
                ],
                sharding=(None, "tensor"),
                transpose=True,
                head_dim_padding=False,
                kv_head_padding=True,
            )

        for proj, sharding in [
            ("gate_proj", (None, "tensor")),
            ("up_proj", (None, "tensor")),
            ("down_proj", ("tensor", None)),
        ]:
            hf_key = f"{prefix}.mlp.{proj}"
            suffix = "weight_q" if is_fp8 else "weight"
            mappings[f"{hf_key}.weight"] = WeightMapping(
                target_path=f"{block}.mlp.{proj}.{suffix}",
                sharding=sharding,
                transpose=True,
            )
            if is_fp8:
                mappings[f"{hf_key}.weight_scale_inv"] = WeightMapping(
                    target_path=f"{block}.mlp.{proj}.weight_scale",
                    sharding=(None, None),
                    transpose=False,
                )

        return mappings

    def get_embed_and_head(self):
        return self.model.embed_tokens.embedding.value, self.lm_head.embedding.value

    def set_embed_and_head(self, embed: jax.Array, head: jax.Array) -> None:
        self.model.embed_tokens.embedding.value = embed
        self.lm_head.embedding.value = head

    def set_embed(self, embed: jax.Array) -> None:
        self.model.embed_tokens.embedding.value = embed


EntryClass = MiMoV2MTPForCausalLM
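Not part of the commit: a standalone sketch of the per-layer key set implied by the FP8 branch of _create_weight_mappings above. It reproduces the "48 tensors = 16 per layer × 3" count cited in the commit message; the actual assertions (exact key coverage, per-layer-idx selection, unique target paths) live in test_mimo_v2_nextn.py, which is not shown in this diff.

def expected_mtp_layer_keys(idx: int) -> set[str]:
    # Safetensors keys the FP8 branch of _create_weight_mappings maps for one MTP layer.
    p = f"model.mtp.layers.{idx}"
    keys = {
        f"{p}.enorm.weight",
        f"{p}.hnorm.weight",
        f"{p}.eh_proj.weight",
        f"{p}.final_layernorm.weight",
        f"{p}.input_layernorm.weight",
        f"{p}.pre_mlp_layernorm.weight",
        f"{p}.self_attn.o_proj.weight",
        f"{p}.self_attn.attention_sink_bias",
        f"{p}.self_attn.qkv_proj.weight",
        f"{p}.self_attn.qkv_proj.weight_scale_inv",
    }
    for proj in ("gate_proj", "up_proj", "down_proj"):
        keys.add(f"{p}.mlp.{proj}.weight")
        keys.add(f"{p}.mlp.{proj}.weight_scale_inv")
    return keys

assert all(len(expected_mtp_layer_keys(i)) == 16 for i in range(3))
assert len(set().union(*(expected_mtp_layer_keys(i) for i in range(3)))) == 48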
