-
Notifications
You must be signed in to change notification settings - Fork 444
[MoE] MiniMax-M2/M2.1 calibration follow-up #2335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kylesayrs
merged 13 commits into
vllm-project:main
from
LudovicoYIN:moe/minimax-m2-calibration-followup
Mar 18, 2026
+332
−0
Merged
Changes from 5 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
adaf58f
[MoE] Align MiniMaxM2 non-calibration path with upstream experts
LudovicoYIN 63f25ba
fix import error
LudovicoYIN d205050
fix example
LudovicoYIN 41865b8
fp8 model to bf16
LudovicoYIN ea45d08
Fix MiniMax M2 MoE calibration and example
LudovicoYIN 2fed675
Merge branch 'main' into moe/minimax-m2-calibration-followup
LudovicoYIN 4b6e476
Merge branch 'main' into moe/minimax-m2-calibration-followup
LudovicoYIN fc922b9
update test and example
LudovicoYIN 149d383
revert minimax_m2 transformers compatibility adaption; fix lint
LudovicoYIN 0ce22bd
Merge branch 'main' into moe/minimax-m2-calibration-followup
HDCharles 457dbaa
Merge branch 'main' into moe/minimax-m2-calibration-followup
HDCharles ef161c0
Merge branch 'main' into moe/minimax-m2-calibration-followup
LudovicoYIN b1bc069
Merge branch 'main' into moe/minimax-m2-calibration-followup
kylesayrs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
import torch
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
# Imported for its registration side effect: importing this module registers
# the calibration replacement for `MiniMaxM2SparseMoeBlock` (hence the noqa).
from llmcompressor.modeling.minimax_m2 import (  # noqa: F401
    CalibrationMiniMaxM2SparseMoeBlock,
)
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier

# Load the model
model_id = "ludovicoYIN/MiniMax-M2-BF16"
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, dtype=torch.bfloat16, config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# MoE calibration is handled automatically by the pipeline.
# The `CalibrationMiniMaxM2SparseMoeBlock` modules (from
# `llmcompressor.modeling.minimax_m2`) will be applied during calibration to enable
# proper expert calibration. These replace the original
# `MiniMaxM2SparseMoeBlock` class from
# `transformers.models.minimax_m2.modeling_minimax_m2`.

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess. Slice the split server-side so only the
# calibration subset is downloaded, then shuffle deterministically.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)
||
def preprocess(example):
    """Render one chat transcript into a plain text string via the chat template.

    Tokenization happens later in `tokenize`, so `tokenize=False` here.
    """
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}
# Render every sample's chat messages to text (adds a "text" column).
ds = ds.map(preprocess)
# Tokenize inputs.
def tokenize(sample):
    """Tokenize the rendered text, truncated to MAX_SEQUENCE_LENGTH.

    No padding and no special tokens: the chat template already inserted
    whatever markers the model expects.
    """
    text = sample["text"]
    return tokenizer(
        text,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding=False,
        add_special_tokens=False,
    )
# Tokenize all samples; drop the raw columns so only token ids remain.
ds = ds.map(tokenize, remove_columns=ds.column_names)
# Keep the LM head and the MoE router gates in full precision.
moe_ignores = [
    "lm_head",
    "re:.*block_sparse_moe.gate$",
]

# Experts live under `model.layers.*.block_sparse_moe.experts.<idx>.(w1|w2|w3)`.
EXPERT_TARGET_REGEX = [
    rf"re:.*block_sparse_moe\.experts\.\d+\.{proj}$" for proj in ("w1", "w2", "w3")
]
# AWQ recipe: W4A16 weight-only quantization of the expert projections.
recipe = AWQModifier(
    targets=EXPERT_TARGET_REGEX,
    scheme="W4A16",
    ignore=moe_ignores,
    mappings=[
        # Smooth w1/w3 (gate/up projections) against the preceding
        # post-attention layernorm output.
        AWQMapping(
            "re:.*post_attention_layernorm$",
            ["re:.*w1$", "re:.*w3$"],
        ),
        # Smooth w2 (down projection) against w3's output.
        AWQMapping("re:.*w3$", ["re:.*w2$"]),
    ],
    duo_scaling=False,
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    processor=tokenizer,
    recipe=recipe,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    # Calibrate decoder layers one at a time to bound peak memory.
    sequential_targets=["MiniMaxM2DecoderLayer"],
)

# Save to disk compressed. Output dir is "<repo-name>-W4A16".
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from llmcompressor.modeling.moe_context import MoECalibrationModule | ||
|
|
||
| if TYPE_CHECKING: | ||
| from transformers import MiniMaxM2Config | ||
| from transformers.models.minimax_m2.modeling_minimax_m2 import ( | ||
| MiniMaxM2SparseMoeBlock, | ||
| ) | ||
|
|
||
|
|
||
@MoECalibrationModule.register("MiniMaxM2SparseMoeBlock")
class CalibrationMiniMaxM2SparseMoeBlock(MoECalibrationModule):
    """Calibration module for ``MiniMaxM2SparseMoeBlock``.

    Reuses the original block's gate and experts. When
    ``calibrate_all_experts`` is set, every expert is executed on all tokens
    so each expert observes calibration activations, while the aggregated
    output still uses only the routed tokens (matching upstream results).
    """

    # Not permanent: the original module is restored after calibration.
    is_permanent = False

    def __init__(
        self,
        original: MiniMaxM2SparseMoeBlock,
        config: MiniMaxM2Config,
        calibrate_all_experts: bool = True,
    ):
        super().__init__()

        # Gating
        self.calibrate_all_experts = calibrate_all_experts

        # Extract submodules directly to prevent parameter duplication
        # in find_tied_parameters (caused by holding self.original)
        self.gate = original.gate
        self.experts = original.experts

        # MiniMax specific parameters
        self.jitter_noise = original.jitter_noise
        self.num_experts = config.num_local_experts
        self.top_k = original.top_k
        # Use unbound function so this module's buffers are used.
        self._route_tokens_to_experts = type(original).route_tokens_to_experts
        self.register_buffer(
            "e_score_correction_bias", original.e_score_correction_bias
        )

    def forward(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with optional all-expert calibration mode.

        - `calibrate_all_experts=False`: use upstream expert execution path.
        - `calibrate_all_experts=True`: execute every expert on all tokens,
          then aggregate only routed-token outputs.

        Returns the mixed hidden states (same shape as the input) together
        with the raw router logits.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Multiplicative jitter noise, applied in training mode only.
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)
        # Keep the correction-bias buffer on the router logits' device; the
        # unbound routing function below reads it from this module.
        if self.e_score_correction_bias.device != router_logits.device:
            self.e_score_correction_bias = self.e_score_correction_bias.to(router_logits.device)
        top_k_index, top_k_weights = self._route_tokens_to_experts(self, router_logits)

        # Accumulator for the weighted expert outputs, one row per token.
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # One-hot over experts, permuted to (num_experts, top_k, num_tokens).
        expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)

        for expert_idx, expert_layer in enumerate(self.experts):
            # Which top-k slot and which tokens were routed to this expert.
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

            if self.calibrate_all_experts:
                # Run the expert on every token (for calibration statistics),
                # then keep only the outputs for its routed tokens.
                expert_out = expert_layer(hidden_states)[token_idx]
            else:
                expert_out = expert_layer(hidden_states[token_idx])

            if token_idx.numel() > 0:
                expert_weights = top_k_weights[token_idx, top_k_pos]
                weighted_output = expert_out * expert_weights.unsqueeze(-1)
                final_hidden_states.index_add_(
                    0,
                    token_idx,
                    weighted_output.to(hidden_states.dtype),
                )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )

        return final_hidden_states, router_logits

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        # Non-permanent module: simply hand back the original block.
        return original
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,145 @@ | ||
import contextlib
from unittest import mock

import pytest
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from llmcompressor.modeling.minimax_m2 import CalibrationMiniMaxM2SparseMoeBlock
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.utils.dev import skip_weights_download
from llmcompressor.utils.helpers import calibration_forward_context
from tests.testing_utils import requires_cadence, requires_gpu

# MiniMax-M2 only exists in recent transformers releases; skip this module
# on older versions instead of failing at import time.
MiniMaxM2Config = pytest.importorskip(
    "transformers.models.minimax_m2.configuration_minimax_m2",
    reason="MiniMaxM2Config not available in this version of transformers",
).MiniMaxM2Config
MiniMaxM2SparseMoeBlock = pytest.importorskip(
    "transformers.models.minimax_m2.modeling_minimax_m2",
    reason="MiniMaxM2SparseMoeBlock not available in this version of transformers",
).MiniMaxM2SparseMoeBlock
@requires_cadence("weekly")
@pytest.mark.parametrize("model_stub", ["hf-internal-testing/MiniMax-M2-Small"])
def test_calib_replace_minimax_m2_all_experts(model_stub):
    """Every expert's projections must execute when calibrate_all_experts=True."""
    with skip_weights_download():
        model = AutoModelForCausalLM.from_pretrained(model_stub)

    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))

        # Grab the first MoE block that the calibration context swapped in.
        moe_layer = next(
            (
                module
                for module in model.modules()
                if isinstance(module, CalibrationMiniMaxM2SparseMoeBlock)
            ),
            None,
        )
        assert moe_layer is not None

        num_experts = moe_layer.experts.num_experts
        # Map each expert weight's storage pointer back to its expert index so
        # the F.linear spy below can tell which expert a call belongs to.
        gate_ptrs = {
            moe_layer.experts.gate_up_proj[idx].data_ptr(): idx
            for idx in range(num_experts)
        }
        down_ptrs = {
            moe_layer.experts.down_proj[idx].data_ptr(): idx
            for idx in range(num_experts)
        }
        seen_gate = [False] * num_experts
        seen_down = [False] * num_experts

        original_linear = F.linear

        def patched_linear(input, weight, *args, **kwargs):
            # Record which expert weights flow through F.linear.
            ptr = weight.data_ptr()
            if ptr in gate_ptrs:
                seen_gate[gate_ptrs[ptr]] = True
            if ptr in down_ptrs:
                seen_down[down_ptrs[ptr]] = True
            return original_linear(input, weight, *args, **kwargs)

        # Create dummy input tensor that simulates hidden_states
        batch, seq_len = 2, 8
        sample = torch.randn(
            batch,
            seq_len,
            model.config.hidden_size,
            dtype=moe_layer.experts.gate_up_proj.dtype,
            device=moe_layer.experts.gate_up_proj.device,
        )

        with torch.no_grad():
            F.linear = patched_linear  # patch only within this scope
            try:
                _ = moe_layer(sample)
            finally:
                F.linear = original_linear

        assert all(seen_gate), f"Not all experts were run (gate_up): {seen_gate}"
        assert all(seen_down), f"Not all experts were run (down_proj): {seen_down}"
@requires_gpu
def test_calib_minimax_m2_module():
    """Both calibration modes must reproduce the upstream block's output."""
    config = MiniMaxM2Config(
        hidden_size=16,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
    )
    with torch.device("cuda"):
        original = MiniMaxM2SparseMoeBlock(config).eval()

    sample = torch.randn(2, 4, config.hidden_size, device="cuda")

    # Reference output from the unmodified upstream block.
    with calibration_forward_context(original):
        true_output = original(sample)

    for calibrate_all_experts in (True, False):
        module = CalibrationMiniMaxM2SparseMoeBlock(
            original, config, calibrate_all_experts
        )
        with calibration_forward_context(module):
            output = module(sample)
        assert torch.allclose(true_output, output, atol=1e-5)
def test_calib_minimax_m2_uses_upstream_experts_when_not_calibrating_all():
    """With calibrate_all_experts=False, delegate to the experts module once."""
    config = MiniMaxM2Config(
        hidden_size=16,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
    )
    block = MiniMaxM2SparseMoeBlock(config).eval()
    module = CalibrationMiniMaxM2SparseMoeBlock(block, config, False)

    sample = torch.randn(2, 4, config.hidden_size)

    # Reference output from the unmodified upstream block.
    with calibration_forward_context(block):
        expected = block(sample)

    # Spy on the shared experts module: the non-calibrating path should hit
    # its forward exactly once and still match the upstream output.
    with mock.patch.object(
        module.experts, "forward", wraps=module.experts.forward
    ) as mocked_forward:
        with calibration_forward_context(module):
            actual = module(sample)

    assert mocked_forward.call_count == 1
    assert torch.allclose(expected, actual, atol=1e-5)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.