
[Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 model for FastDeploy#7510

Open
bobby-cloudforge wants to merge 2 commits into PaddlePaddle:develop from CloudForge-Solutions:task/h10-047-minimax-m1-model1

Conversation

@bobby-cloudforge

Motivation

Add the capability to deploy the MiniMaxAI/MiniMax-M1-40k model family with FastDeploy.

This PR adds support for deploying the MiniMax-M1 (456B MoE, 45.9B active) model family in FastDeploy, as required by Hackathon 10th Spring No.47.

MiniMax-M1 is a hybrid-attention Mixture-of-Experts LLM with:

  • Lightning Attention: 70 out of 80 layers use linear-complexity attention (O(n) vs O(n²))
  • Full GQA: 10 layers (indices 7,15,23,31,39,47,55,63,71,79) use standard grouped-query attention
  • MoE: 32 experts with top-2 routing per token
  • DeepNorm: Separate alpha/beta scaling for linear vs full attention layers
  • Postnorm: Residual carries normed activations (differs from standard pre-norm)
  • Architecture registered as both MiniMaxM1ForCausalLM and MiniMaxText01ForCausalLM
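The hybrid layer layout described above can be sketched in a few lines. This is a minimal illustration only; the function name is hypothetical, not part of FastDeploy's API:

```python
# Sketch of the 80-layer hybrid layout: every 8th layer (indices 7, 15, ..., 79)
# uses full grouped-query attention, the remaining 70 use Lightning (linear)
# attention. The helper name is illustrative.
NUM_LAYERS = 80

def build_attn_type_list(num_layers: int = NUM_LAYERS) -> list:
    # 1 = full GQA, 0 = linear (Lightning) attention
    return [1 if (i % 8) == 7 else 0 for i in range(num_layers)]

layout = build_attn_type_list()
assert sum(layout) == 10  # 10 full-attention layers out of 80
```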

Modifications

Model Code (fastdeploy/model_executor/models/minimax_m1.py, ~800 lines)

9 classes implementing the full model:

  • MiniMaxM1MLP: Gate/up merged projection with SiLU activation
  • MiniMaxM1MoE: FusedMoE with 32 experts, top-2 routing, renormalize=True, quantization-aware weight_key_map (w4a8, w4afp8 static/dynamic, tensor_wise_fp8, block_wise_fp8)
  • MiniMaxM1FullAttention: Standard GQA with RoPE, used in 10 out of 80 layers
  • MiniMaxM1LinearAttention: Lightning attention with SiLU-gated QKV, output_gate (sigmoid), RMSNorm, persistent KV state history
  • MiniMaxM1DecoderLayer: Dispatches to linear/full attention based on attn_type_list, DeepNorm scaling with separate alpha/beta per attention type, postnorm support
  • MiniMaxM1Model: Full transformer with embedding and final RMSNorm
  • MiniMaxM1ForCausalLM: Causal LM wrapper with dual weight loading (v0 set_state_dict + v1 load_weights)
  • MiniMaxM1PretrainedModel: Tensor parallel column/row split mappings
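As a rough illustration of the DeepNorm scaling and postnorm behavior noted for MiniMaxM1DecoderLayer: a sketch assuming RMSNorm applied post-residual with a per-layer alpha scale on the residual branch (beta scales weight init and is not shown). This is not the PR's actual code:

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Root-mean-square normalization over the last dimension.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def postnorm_deepnorm_residual(hidden, sublayer_out, alpha):
    # Post-norm with DeepNorm-style alpha scaling: the incoming residual is
    # scaled by alpha, the sublayer output is added, and the sum is normed.
    # The normed result is what flows onward, matching "residual carries
    # normed activations" above.
    return rms_norm(alpha * hidden + sublayer_out)
```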

Lightning Attention Kernels (fastdeploy/model_executor/ops/triton_ops/lightning_attn.py, 711 lines)

Triton kernels for O(n) linear attention with exponential decay:

  • _fwd_kernel: Intra-block attention with causal masking and decay factors
  • _fwd_kv_kernel: Inter-block KV state accumulation with block-level decay
  • lightning_attention(): Python wrapper dispatching to Triton with automatic block size, dtype management, and KV history persistence
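The recurrence these blocked kernels implement can be written as a naive O(n) NumPy reference: a (d x d) KV state is decayed and updated per token, and the final state can be carried into the next call, mimicking KV history persistence. This is a simplified single-head sketch; names and signature are illustrative, not the kernel wrapper's API:

```python
import numpy as np

def lightning_attention_ref(q, k, v, slope, kv_state=None):
    """Naive recurrent reference for decayed linear attention.

    q, k, v: (seq_len, d) single-head inputs; slope: per-head decay rate.
    kv_state: optional (d, d) state carried over from a previous call.
    """
    seq_len, d = q.shape
    decay = np.exp(-slope)
    kv = np.zeros((d, d)) if kv_state is None else kv_state.copy()
    out = np.empty_like(q)
    for t in range(seq_len):
        # Old state decays; the current token contributes k_t^T v_t.
        kv = decay * kv + np.outer(k[t], v[t])
        out[t] = q[t] @ kv
    return out, kv
```

Splitting a sequence across two calls and threading the returned state through should reproduce the single-call result, which is the property the KV history persistence relies on.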

Documentation

  • docs/best_practices/MiniMax-M1.md + docs/zh/best_practices/MiniMax-M1.md: Bilingual deployment guide
  • docs/supported_models.md + docs/zh/supported_models.md: Added MiniMax-M1 to LLM model table

Usage or Command

python -m fastdeploy.entrypoints.openai.api_server \
       --model MiniMaxAI/MiniMax-M1-40k \
       --tensor-parallel-size 8 \
       --max-model-len 40960 \
       --max-num-seqs 64

See docs/best_practices/MiniMax-M1.md for full deployment guide.

Accuracy Tests

Unit Tests (36/36 passed — CI verified on H20 GPU)

  • Test file: tests/model_executor/test_minimax_m1.py (576 lines, 36 function-based tests)
  • Pure-logic tests (12): _build_attn_type_list correctness, _build_slope_tensor shape and values
  • Registration tests (5): Primary + alias architecture in ModelRegistry, class identity
  • Construction tests (9): Linear vs full attention dispatch, DeepNorm defaults, MoE vs dense MLP, quantization weight_key_map
  • Forward-pass smoke tests (4): Linear/full-attn layer output shapes, DeepNorm scaling, postnorm path
  • Lightning Attention NumPy reference (4): Single-token, multi-token causal, KV persistence, multi-head independence
  • Style: Function-based pytest, monkeypatch.setattr + lightweight stubs

Pre-commit Validation

All hooks passing: black, isort, flake8, ruff, clang-format, trailing whitespace, large file check.

Checklist

  • Model code (minimax_m1.py, ~800 lines) — 9 classes with full weight loading + quantization
  • Lightning Attention Triton kernels (lightning_attn.py, 711 lines)
  • Unit tests (36/36 passing, ~576 lines) — function-based pytest with NumPy reference
  • Low-bit quantization: w4a8, w4afp8 (static/dynamic), tensor_wise_fp8, block_wise_fp8
  • Documentation (EN + CN best practices, supported models)
  • HF weight key mapping verified
  • Both v0 and v1 loader paths implemented
  • Dual architecture registration: MiniMaxM1ForCausalLM + MiniMaxText01ForCausalLM
  • Pre-commit hooks all passing

@paddle-bot

paddle-bot bot commented Apr 20, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Apr 20, 2026

- rotary_embedding.py: use 'MiniMax' prefix to match both MiniMaxM1 and
  MiniMaxText01 architectures (was missing HF alias → wrong RoPE)
- test_minimax_m1.py: assert residual is None (DeepNorm folds residual
  into hidden_states, so decoder returns None)

@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-20 17:05:29

## 📋 Review Summary

PR overview: adds MiniMax-M1 (456B MoE, 70 linear-attention layers + 10 full-attention layers) model support to FastDeploy, including Lightning Attention Triton kernels, dual-path weight loading, and dual architecture registration.
Scope of changes: model_executor/models/, model_executor/ops/triton_ops/, model_executor/layers/rotary_embedding.py, docs/, tests/
Impact tags: Models, OP

### Issues

| Level | File | Summary |
|------|------|------|
| 🔴 Bug | minimax_m1.py:358 | Linear-attention KV history is stored in an instance variable; concurrent requests pollute each other's state |
| 🟡 Suggestion | minimax_m1.py:685 | v0 set_state_dict and v1 load_weights handle weights inconsistently, and incomplete buffers trigger no warning |
| 🟡 Suggestion | lightning_attn.py:576 | _fwd_none_diag_kernel grid dimensions do not match the program_id(2) used inside the kernel |
| ❓ Question | rotary_embedding.py:344 | Is reusing the Qwen RoPE path for MiniMax confirmed to be compatible? |

### Overall Assessment

The model implementation is cleanly structured, with good coverage of complex features such as MoE, DeepNorm, and the dual weight-loading paths. The core blocking issue is that storing Linear Attention's _kv_history in an instance variable causes correctness problems under concurrent serving; at minimum, add an explicit limitation note or warning before merging. The Triton kernels are well implemented, but the grid configuration must stay consistent with the kernels.

# Retrieve or initialize KV history for recurrent state persistence.
# TODO: Migrate to ForwardMeta.caches / slot-based cache management for
# proper multi-request isolation in production serving scenarios.
if not hasattr(self, "_kv_history") or self._kv_history is None or self._kv_history.shape[0] != batch_size:


🔴 Bug The linear-attention KV history is stored in the instance variable _kv_history. Under concurrent multi-request serving, different requests share the same KV state, producing incorrect inference results.

Although the code carries a TODO, this is a blocking correctness issue: FastDeploy is an inference serving framework, and concurrent serving is a core scenario. Under the current implementation:

  1. Cross-request pollution: request A's KV history is overwritten by request B, corrupting the attention output of both requests
  2. State loss: whenever batch_size changes, _kv_history is reset to a zero tensor (discarding all history), which triggers frequently under continuous batching

Suggestions:

  1. Migrate the KV history to ForwardMeta.caches or a slot-based cache so each request keeps independent KV state
  2. At minimum, add a logger.warning in this PR noting that the current implementation only supports single-request scenarios, to avoid misuse in production
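A minimal sketch of the slot-based direction suggested in point 1. Class and method names, and the flat (d, d) state shape, are illustrative; FastDeploy's actual ForwardMeta.caches management is more involved:

```python
import numpy as np

class SlotKVCache:
    """Per-slot KV state so each request keeps its own recurrent state,
    instead of a single shared instance attribute. Illustrative only."""

    def __init__(self, num_slots: int, d: int):
        self._states = np.zeros((num_slots, d, d))

    def get(self, slot_id: int) -> np.ndarray:
        return self._states[slot_id]

    def update(self, slot_id: int, new_state: np.ndarray) -> None:
        self._states[slot_id] = new_state

    def reset(self, slot_id: int) -> None:
        # Called when a slot is reassigned to a new request.
        self._states[slot_id] = 0.0
```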

name = re.sub(r"\.w2\.", ".down_proj.", name)
renamed[name] = weight
# Full attention: merge separate q/k/v into qkv_proj
elif ".self_attn.q_proj." in name or ".self_attn.k_proj." in name or ".self_attn.v_proj." in name:


🟡 Suggestion set_state_dict (v0 path) and load_weights (v1 path) merge q/k/v weights with inconsistent logic.

  • The v1 path explicitly distinguishes linear-attention layers (plain concat, no shard_id) from full-attention layers (with shard_id), and logs a warning for incomplete qkv buffers
  • The v0 path applies concat([q, k, v], axis=0) uniformly to all layers, and incomplete buffers are silently dropped

The result happens to be correct today (both qkv_proj variants expect an axis-0 concat), but consider:

  1. Adding the same incomplete-buffer warning log as the v1 path
  2. Explaining in a comment why the v0 path does not need to distinguish linear from full attention, to avoid confusing future maintainers
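Suggestion 1 could look roughly like this: an axis-0 concat of the buffered q/k/v weights that warns on, rather than silently drops, an incomplete buffer. The helper name and buffer layout are hypothetical:

```python
import logging
from typing import Optional

import numpy as np

logger = logging.getLogger(__name__)

def merge_qkv(buffers: dict) -> Optional[np.ndarray]:
    """Concatenate buffered q/k/v weights into one qkv tensor along axis 0.

    Logs a warning and skips the merge when any projection is missing,
    instead of silently discarding the partial buffer.
    """
    missing = [p for p in ("q", "k", "v") if p not in buffers]
    if missing:
        logger.warning("Incomplete qkv buffer, missing %s; skipping merge.", missing)
        return None
    return np.concatenate([buffers["q"], buffers["k"], buffers["v"]], axis=0)
```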

)

# Step 4: Compute non-diagonal blocks of attention
grid = (b * h, NUM_BLOCK * NUM_CBLOCK)


🟡 Suggestion _fwd_none_diag_kernel uses tl.program_id(2) internally to index off_e (the feature-block offset), but the grid here only has two dimensions: (b * h, NUM_BLOCK * NUM_CBLOCK).

With NUM_FBLOCK = 1 today, program_id(2) returns 0 and the result is correct. But a grid whose dimensions do not match the kernel's program_id usage is a latent hazard:

  • If NUM_FBLOCK > 1 in the future, the kernel will silently process only the first feature block, producing wrong results
  • It hurts readability: the kernel code implies a third parallel dimension that the launch does not provide

Suggestion: launch with an explicit three-dimensional grid:

grid = (b * h, NUM_BLOCK * NUM_CBLOCK, NUM_FBLOCK)


architecture = model_config.architectures[0]
if architecture.startswith("Qwen"):
if architecture.startswith("Qwen") or architecture.startswith("MiniMax"):


❓ Question MiniMax-M1's full-attention layers reuse QwenRotaryEmbedding; please confirm that MiniMax-M1's RoPE implementation is fully compatible with the Qwen series (including how rope_theta, rope_scaling, and related parameters are handled).

Also, architecture.startswith("MiniMax") matches broadly; if a future MiniMax architecture uses a different RoPE implementation, prefer a more precise match, e.g. architecture in ("MiniMaxM1ForCausalLM", "MiniMaxText01ForCausalLM").
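The exact-match suggestion could be sketched as follows. The two architecture names come from this PR's dual registration; the helper name is illustrative:

```python
# Exact-match dispatch instead of a broad startswith("MiniMax") prefix test.
MINIMAX_M1_ARCHS = ("MiniMaxM1ForCausalLM", "MiniMaxText01ForCausalLM")

def uses_qwen_style_rope(architecture: str) -> bool:
    # Qwen keeps its prefix match; MiniMax is pinned to the two known names,
    # so a future MiniMax architecture with a different RoPE is not captured.
    return architecture.startswith("Qwen") or architecture in MINIMAX_M1_ARCHS
```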
