|
| 1 | +"""MiMo-V2.5-Pro multi-token-prediction (NextN/MTP) draft model. |
| 2 | +
|
| 3 | +One instance loads exactly one of the 3 MTP layers from |
| 4 | +``model_mtp.safetensors`` (selected via ``config.mtp_layer_idx``); the |
| 5 | +``MultiLayerDraftWorker`` (#1053 P1-4) creates one instance per layer. The |
| 6 | +MTP block is a SWA-attention + dense-MLP decoder (not MoE), with the same |
| 7 | +fused-QKV per-shard FP8 layout as the V2.5-Pro target. |
| 8 | +""" |
| 9 | + |
| 10 | +import logging |
| 11 | +from types import SimpleNamespace |
| 12 | + |
| 13 | +import jax |
| 14 | +import jax.numpy as jnp |
| 15 | +from flax import nnx |
| 16 | +from transformers import PretrainedConfig |
| 17 | + |
| 18 | +from sgl_jax.srt.configs.model_config import ModelConfig |
| 19 | +from sgl_jax.srt.layers.embeddings import Embed, ParallelLMHead |
| 20 | +from sgl_jax.srt.layers.layernorm import RMSNorm |
| 21 | +from sgl_jax.srt.layers.linear import LinearBase |
| 22 | +from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor |
| 23 | +from sgl_jax.srt.mem_cache.memory_pool import KVCache, MemoryPools |
| 24 | +from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch |
| 25 | +from sgl_jax.srt.models.mimo_v2_flash import MiMoV2Attention, MiMoV2MLP |
| 26 | +from sgl_jax.srt.utils.weight_utils import WeightLoader, WeightMapping |
| 27 | + |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | + |
| 31 | +class MiMoV2NextNDecoderLayer(nnx.Module): |
| 32 | + """Single MTP decoder block: SWA attention + dense MLP (never MoE).""" |
| 33 | + |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | + config: PretrainedConfig, |
| 37 | + mesh: jax.sharding.Mesh, |
| 38 | + layer_id: int = 0, |
| 39 | + dtype: jnp.dtype = jnp.bfloat16, |
| 40 | + ): |
| 41 | + self.layer_id = layer_id |
| 42 | + rope_theta = getattr(config, "rope_theta", 1000000) |
| 43 | + rope_scaling = getattr(config, "rope_scaling", None) |
| 44 | + if isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "default": |
| 45 | + rope_scaling = None |
| 46 | + max_position_embeddings = getattr(config, "max_position_embeddings", 32768) |
| 47 | + |
| 48 | + self.self_attn = MiMoV2Attention( |
| 49 | + hidden_size=config.hidden_size, |
| 50 | + num_heads=config.swa_num_attention_heads, |
| 51 | + num_kv_heads=config.swa_num_key_value_heads, |
| 52 | + max_position_embeddings=max_position_embeddings, |
| 53 | + rope_theta=getattr(config, "swa_rope_theta", rope_theta), |
| 54 | + rope_scaling=rope_scaling, |
| 55 | + head_dim=config.swa_head_dim, |
| 56 | + v_head_dim=getattr(config, "swa_v_head_dim", None), |
| 57 | + sliding_window_size=getattr(config, "sliding_window_size", None), |
| 58 | + attention_sink_bias=getattr(config, "add_swa_attention_sink_bias", False), |
| 59 | + partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0), |
| 60 | + attention_value_scale=getattr(config, "attention_value_scale", None), |
| 61 | + layer_id=layer_id, |
| 62 | + dtype=dtype, |
| 63 | + mesh=mesh, |
| 64 | + ) |
| 65 | + self.mlp = MiMoV2MLP( |
| 66 | + hidden_size=config.hidden_size, |
| 67 | + intermediate_size=config.intermediate_size, |
| 68 | + layer_id=layer_id, |
| 69 | + dtype=dtype, |
| 70 | + mesh=mesh, |
| 71 | + ) |
| 72 | + self.input_layernorm = RMSNorm( |
| 73 | + config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype |
| 74 | + ) |
| 75 | + self.post_attention_layernorm = RMSNorm( |
| 76 | + config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype |
| 77 | + ) |
| 78 | + |
| 79 | + def __call__( |
| 80 | + self, |
| 81 | + positions: jax.Array, |
| 82 | + hidden_states: jax.Array, |
| 83 | + forward_batch: ForwardBatch, |
| 84 | + token_to_kv_pool: KVCache, |
| 85 | + residual: jax.Array | None = None, |
| 86 | + ) -> tuple[jax.Array, jax.Array, jax.Array, None]: |
| 87 | + if residual is None: |
| 88 | + residual = hidden_states |
| 89 | + hidden_states = self.input_layernorm(hidden_states) |
| 90 | + else: |
| 91 | + hidden_states = hidden_states + residual |
| 92 | + residual = hidden_states |
| 93 | + hidden_states = self.input_layernorm(hidden_states) |
| 94 | + |
| 95 | + hidden_states, kv_fused = self.self_attn( |
| 96 | + positions=positions, |
| 97 | + hidden_states=hidden_states, |
| 98 | + forward_batch=forward_batch, |
| 99 | + token_to_kv_pool=token_to_kv_pool, |
| 100 | + ) |
| 101 | + |
| 102 | + hidden_states = hidden_states + residual |
| 103 | + residual = hidden_states |
| 104 | + hidden_states = self.post_attention_layernorm(hidden_states) |
| 105 | + hidden_states = self.mlp(hidden_states) |
| 106 | + return hidden_states, residual, kv_fused, None |
| 107 | + |
| 108 | + |
| 109 | +class MiMoV2ModelNextN(nnx.Module): |
| 110 | + |
| 111 | + def __init__( |
| 112 | + self, |
| 113 | + config: PretrainedConfig, |
| 114 | + mesh: jax.sharding.Mesh, |
| 115 | + dtype: jnp.dtype = jnp.bfloat16, |
| 116 | + ): |
| 117 | + self.embed_tokens = Embed( |
| 118 | + num_embeddings=config.vocab_size, |
| 119 | + features=config.hidden_size, |
| 120 | + dtype=dtype, |
| 121 | + kernel_axes=("tensor", None), |
| 122 | + param_dtype=dtype, |
| 123 | + mesh=mesh, |
| 124 | + ) |
| 125 | + self.enorm = RMSNorm( |
| 126 | + config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype |
| 127 | + ) |
| 128 | + self.hnorm = RMSNorm( |
| 129 | + config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype |
| 130 | + ) |
| 131 | + self.eh_proj = LinearBase( |
| 132 | + input_size=2 * config.hidden_size, |
| 133 | + output_size=config.hidden_size, |
| 134 | + use_bias=False, |
| 135 | + kernel_axes=(None, None), |
| 136 | + params_dtype=dtype, |
| 137 | + mesh=mesh, |
| 138 | + ) |
| 139 | + self.mtp_block = MiMoV2NextNDecoderLayer(config, mesh=mesh, layer_id=0, dtype=dtype) |
| 140 | + self.final_layernorm = RMSNorm( |
| 141 | + config.hidden_size, epsilon=config.layernorm_epsilon, param_dtype=dtype |
| 142 | + ) |
| 143 | + |
| 144 | + def __call__( |
| 145 | + self, forward_batch: ForwardBatch, token_to_kv_pool: KVCache |
| 146 | + ) -> tuple[jax.Array, list[jax.Array]]: |
| 147 | + embed = self.embed_tokens(forward_batch.input_ids) |
| 148 | + hidden_in = forward_batch.spec_info.hidden_states |
| 149 | + hidden_states, _ = self.eh_proj( |
| 150 | + jnp.concatenate((self.enorm(embed), self.hnorm(hidden_in)), axis=-1) |
| 151 | + ) |
| 152 | + hidden_states, residual, kv_fused, _ = self.mtp_block( |
| 153 | + forward_batch.positions, hidden_states, forward_batch, token_to_kv_pool, None |
| 154 | + ) |
| 155 | + hidden_states = self.final_layernorm(hidden_states + residual) |
| 156 | + return hidden_states, [kv_fused] |
| 157 | + |
| 158 | + |
| 159 | +class MiMoV2MTPForCausalLM(nnx.Module): |
| 160 | + |
| 161 | + load_lm_head_from_target = True |
| 162 | + |
| 163 | + def __init__( |
| 164 | + self, |
| 165 | + config: PretrainedConfig, |
| 166 | + mesh: jax.sharding.Mesh | None = None, |
| 167 | + dtype: jnp.dtype = jnp.bfloat16, |
| 168 | + ): |
| 169 | + self.config = config |
| 170 | + self.mesh = mesh |
| 171 | + self.dtype = dtype |
| 172 | + self.mtp_layer_idx = getattr(config, "mtp_layer_idx", 0) |
| 173 | + self.model = MiMoV2ModelNextN(config, mesh=mesh, dtype=dtype) |
| 174 | + self.lm_head = ParallelLMHead( |
| 175 | + config.vocab_size, |
| 176 | + config.hidden_size, |
| 177 | + dtype=dtype, |
| 178 | + param_dtype=dtype, |
| 179 | + kernel_axes=("tensor", None), |
| 180 | + ) |
| 181 | + self.logits_processor = LogitsProcessor(config.vocab_size, mesh=self.mesh) |
| 182 | + self._fused_qkv_buffers: dict[int, dict] = {} |
| 183 | + self.hot_token_ids = None |
| 184 | + |
| 185 | + def __call__( |
| 186 | + self, |
| 187 | + forward_batch: ForwardBatch, |
| 188 | + memory_pools: MemoryPools, |
| 189 | + logits_metadata: LogitsMetadata, |
| 190 | + ): |
| 191 | + hidden_states, layers_kv_fused = self.model(forward_batch, memory_pools.token_to_kv_pool) |
| 192 | + output = self.logits_processor( |
| 193 | + hidden_states, self.lm_head, logits_metadata, aux_hidden_states=None |
| 194 | + ) |
| 195 | + return output, layers_kv_fused, [] |
| 196 | + |
| 197 | + def load_weights(self, model_config: ModelConfig): |
| 198 | + self.loader = WeightLoader( |
| 199 | + model=self, model_config=model_config, mesh=self.mesh, dtype=self.dtype |
| 200 | + ) |
| 201 | + mappings = self._create_weight_mappings() |
| 202 | + self.loader.load_weights_from_safetensors(mappings) |
| 203 | + |
| 204 | + if self.loader.is_static_quant: |
| 205 | + attn = self.model.mtp_block.self_attn |
| 206 | + head_dim, v_head_dim = attn.head_dim, attn.v_head_dim |
| 207 | + # dequant_fused_qkv reads full-attn config fields; the MTP block uses |
| 208 | + # SWA dims, so derive the split config from the actual layer. |
| 209 | + mtp_qkv_config = SimpleNamespace( |
| 210 | + head_dim=head_dim, |
| 211 | + v_head_dim=v_head_dim, |
| 212 | + num_attention_heads=attn.q_head_num, |
| 213 | + num_key_value_heads=attn.k_head_num, |
| 214 | + ) |
| 215 | + self.loader.dequant_fused_qkv( |
| 216 | + self._fused_qkv_buffers, [self.model.mtp_block], mtp_qkv_config |
| 217 | + ) |
| 218 | + self.loader.dequant_fp8_layers( |
| 219 | + [self.model.mtp_block], |
| 220 | + specs=[ |
| 221 | + ("mlp.gate_proj", None), |
| 222 | + ("mlp.up_proj", None), |
| 223 | + ("mlp.down_proj", None), |
| 224 | + ], |
| 225 | + ) |
| 226 | + self.loader.replicate_kv_heads( |
| 227 | + [self.model.mtp_block], |
| 228 | + specs=[("self_attn.k_proj", head_dim), ("self_attn.v_proj", v_head_dim)], |
| 229 | + target_kv_heads_fn=lambda attn: attn.k_head_num, |
| 230 | + ) |
| 231 | + logger.info( |
| 232 | + "MiMoV2 MTP layer %d weights loaded (fused-qkv FP8=%s)", |
| 233 | + self.mtp_layer_idx, |
| 234 | + self.loader.is_static_quant, |
| 235 | + ) |
| 236 | + |
| 237 | + def _create_weight_mappings(self) -> dict[str, WeightMapping]: |
| 238 | + idx = self.mtp_layer_idx |
| 239 | + prefix = f"model.mtp.layers.{idx}" |
| 240 | + block = "model.mtp_block" |
| 241 | + is_fp8 = self.loader.is_static_quant |
| 242 | + |
| 243 | + mappings: dict[str, WeightMapping] = { |
| 244 | + f"{prefix}.enorm.weight": WeightMapping( |
| 245 | + target_path="model.enorm.scale", sharding=(None,), transpose=False |
| 246 | + ), |
| 247 | + f"{prefix}.hnorm.weight": WeightMapping( |
| 248 | + target_path="model.hnorm.scale", sharding=(None,), transpose=False |
| 249 | + ), |
| 250 | + f"{prefix}.eh_proj.weight": WeightMapping( |
| 251 | + target_path="model.eh_proj.weight", sharding=(None, None), transpose=True |
| 252 | + ), |
| 253 | + f"{prefix}.final_layernorm.weight": WeightMapping( |
| 254 | + target_path="model.final_layernorm.scale", sharding=(None,), transpose=False |
| 255 | + ), |
| 256 | + f"{prefix}.input_layernorm.weight": WeightMapping( |
| 257 | + target_path=f"{block}.input_layernorm.scale", sharding=(None,), transpose=False |
| 258 | + ), |
| 259 | + f"{prefix}.pre_mlp_layernorm.weight": WeightMapping( |
| 260 | + target_path=f"{block}.post_attention_layernorm.scale", |
| 261 | + sharding=(None,), |
| 262 | + transpose=False, |
| 263 | + ), |
| 264 | + f"{prefix}.self_attn.o_proj.weight": WeightMapping( |
| 265 | + target_path=f"{block}.self_attn.o_proj.weight", |
| 266 | + sharding=("tensor", None), |
| 267 | + transpose=True, |
| 268 | + head_dim_padding=True, |
| 269 | + ), |
| 270 | + f"{prefix}.self_attn.attention_sink_bias": WeightMapping( |
| 271 | + target_path=f"{block}.self_attn.attention_sink_bias", |
| 272 | + sharding=("tensor",), |
| 273 | + transpose=False, |
| 274 | + ), |
| 275 | + } |
| 276 | + |
| 277 | + qkv_key = f"{prefix}.self_attn.qkv_proj" |
| 278 | + if is_fp8 and not self.loader.is_quant_ignored(qkv_key): |
| 279 | + mappings[f"{qkv_key}.weight"] = WeightMapping( |
| 280 | + target_path="__FUSED_QKV_WEIGHT__0", sharding=(None, None), transpose=False |
| 281 | + ) |
| 282 | + mappings[f"{qkv_key}.weight_scale_inv"] = WeightMapping( |
| 283 | + target_path="__FUSED_QKV_SCALE__0", sharding=(None, None), transpose=False |
| 284 | + ) |
| 285 | + else: |
| 286 | + mappings[f"{qkv_key}.weight"] = WeightMapping( |
| 287 | + target_path=[ |
| 288 | + f"{block}.self_attn.q_proj.weight", |
| 289 | + f"{block}.self_attn.k_proj.weight", |
| 290 | + f"{block}.self_attn.v_proj.weight", |
| 291 | + ], |
| 292 | + sharding=(None, "tensor"), |
| 293 | + transpose=True, |
| 294 | + head_dim_padding=False, |
| 295 | + kv_head_padding=True, |
| 296 | + ) |
| 297 | + |
| 298 | + for proj, sharding in [ |
| 299 | + ("gate_proj", (None, "tensor")), |
| 300 | + ("up_proj", (None, "tensor")), |
| 301 | + ("down_proj", ("tensor", None)), |
| 302 | + ]: |
| 303 | + hf_key = f"{prefix}.mlp.{proj}" |
| 304 | + suffix = "weight_q" if is_fp8 else "weight" |
| 305 | + mappings[f"{hf_key}.weight"] = WeightMapping( |
| 306 | + target_path=f"{block}.mlp.{proj}.{suffix}", |
| 307 | + sharding=sharding, |
| 308 | + transpose=True, |
| 309 | + ) |
| 310 | + if is_fp8: |
| 311 | + mappings[f"{hf_key}.weight_scale_inv"] = WeightMapping( |
| 312 | + target_path=f"{block}.mlp.{proj}.weight_scale", |
| 313 | + sharding=(None, None), |
| 314 | + transpose=False, |
| 315 | + ) |
| 316 | + |
| 317 | + return mappings |
| 318 | + |
| 319 | + def get_embed_and_head(self): |
| 320 | + return self.model.embed_tokens.embedding.value, self.lm_head.embedding.value |
| 321 | + |
| 322 | + def set_embed_and_head(self, embed: jax.Array, head: jax.Array) -> None: |
| 323 | + self.model.embed_tokens.embedding.value = embed |
| 324 | + self.lm_head.embedding.value = head |
| 325 | + |
| 326 | + def set_embed(self, embed: jax.Array) -> None: |
| 327 | + self.model.embed_tokens.embedding.value = embed |
| 328 | + |
| 329 | + |
| 330 | +EntryClass = MiMoV2MTPForCausalLM |
0 commit comments