Skip to content
94 changes: 94 additions & 0 deletions examples/quantizing_moe/minimax_m2_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modeling.minimax_m2 import ( # noqa: F401
CalibrationMiniMaxM2SparseMoeBlock,
)
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier

# Load the model.
# NOTE(review): `dtype="auto"` is the modern spelling (transformers >= 4.56);
# older versions use `torch_dtype="auto"` — confirm against the pinned version.
model_id = "MiniMaxAI/MiniMax-M2"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# MoE calibration is handled automatically by the pipeline.
# The `CalibrationMiniMaxM2SparseMoeBlock` modules (from
# `llmcompressor.modeling.minimax_m2`) will be applied during calibration to enable
# proper expert calibration. These replace the original
# `MiniMaxM2SparseMoeBlock` class from
# `transformers.models.minimax_m2.modeling_minimax_m2`.

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess. The split slice keeps only the samples we need.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    """Render a sample's chat messages into one plain-text string."""
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


# Apply chat-template rendering to every calibration sample.
ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    """Tokenize the rendered text, truncated to MAX_SEQUENCE_LENGTH.

    Special tokens are skipped because the chat template already inserted
    them, and padding is left to the data collator.
    """
    return tokenizer(
        sample["text"],
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding=False,
        add_special_tokens=False,
    )


# Tokenize and drop the raw text/messages columns so only model inputs remain.
ds = ds.map(tokenize, remove_columns=ds.column_names)

moe_ignores = [
    # MoE gate layers are sensitive to quantization.
    "re:.*mlp.gate$",
    # Ignore the output head.
    "lm_head",
]

# Configure the quantization algorithm to run.
# NOTE(review): only the attention q/k/v smoothing mapping is configured here;
# presumably o_proj/MLP mappings are intentionally omitted for this model —
# confirm against the AWQModifier defaults for MiniMax-M2.
recipe = AWQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=moe_ignores,
    mappings=[
        AWQMapping(
            "re:.*input_layernorm$",
            ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
        )
    ],
)


# Apply algorithms. Calibrating one decoder layer at a time keeps memory bounded.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["MiniMaxM2DecoderLayer"],
)

# Save to disk compressed. Directory name is derived from the hub model id.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
5 changes: 5 additions & 0 deletions src/llmcompressor/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401
from .glm4_moe import CalibrationGlm4MoeMoE # noqa: F401
from .llama4 import SequentialLlama4TextMoe # noqa: F401

try:  # Optional dependency: transformers may not include minimax_m2 yet.
    from .minimax_m2 import CalibrationMiniMaxM2SparseMoeBlock  # noqa: F401
# ModuleNotFoundError subclasses ImportError, so catching ImportError covers both.
except ImportError:  # pragma: no cover
    pass
from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401
Expand Down
110 changes: 110 additions & 0 deletions src/llmcompressor/modeling/minimax_m2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

from llmcompressor.modeling.moe_context import MoECalibrationModule

if TYPE_CHECKING:
from transformers import MiniMaxM2Config
from transformers.models.minimax_m2.modeling_minimax_m2 import (
MiniMaxM2SparseMoeBlock,
)


@MoECalibrationModule.register("MiniMaxM2SparseMoeBlock")
class CalibrationMiniMaxM2SparseMoeBlock(MoECalibrationModule):
    """
    Calibration version of MiniMaxM2SparseMoeBlock that can send all tokens
    to all experts during calibration.

    When `calibrate_all_experts=True`, each expert is executed on all tokens so
    quantization statistics are collected for every expert, while routed-token
    weighting is still used for the final output.
    """

    # Swapped back to the original module after calibration (see ``restore``).
    is_permanent = False

    def __init__(
        self,
        original: MiniMaxM2SparseMoeBlock,
        config: MiniMaxM2Config,
        calibrate_all_experts: bool = True,
    ):
        """
        :param original: upstream sparse-MoE block whose submodules are reused
        :param config: model configuration for the wrapped block
        :param calibrate_all_experts: when True, run every expert on all tokens
        """
        super().__init__()
        self.config = config
        self.experts = original.experts
        self.gate = original.gate
        self.calibrate_all_experts = calibrate_all_experts
        self.jitter_noise = original.jitter_noise
        # Buffer (not parameter) so it moves with the module across devices.
        self.register_buffer(
            "e_score_correction_bias", original.e_score_correction_bias
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with optional all-expert calibration mode.

        - `calibrate_all_experts=False`: use upstream expert execution path.
        - `calibrate_all_experts=True`: execute every expert on all tokens,
          then aggregate only routed-token outputs.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            # Matches upstream: multiplicative jitter applied in-place.
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        hidden_states = hidden_states.view(-1, hidden_dim)
        _router_logits, top_k_weights, top_k_index = self.gate(
            hidden_states, self.e_score_correction_bias
        )

        if self.calibrate_all_experts:
            final_hidden_states = self._forward_all_experts(
                hidden_states, top_k_weights, top_k_index
            )
        else:
            # Upstream path: each expert only sees its routed tokens.
            final_hidden_states = self.experts(
                hidden_states, top_k_index, top_k_weights
            )
        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

    def _forward_all_experts(
        self,
        hidden_states: torch.Tensor,
        top_k_weights: torch.Tensor,
        top_k_index: torch.Tensor,
    ) -> torch.Tensor:
        """
        Reimplementation of MiniMaxM2Experts.forward used only when calibrating
        all experts: every expert processes all tokens (so observers collect
        full statistics), but only routed-token outputs — scaled by their
        routing weights — contribute to the result.

        :param hidden_states: flat (num_tokens, hidden_dim) activations
        :param top_k_weights: per-token routing weights for the top-k experts
        :param top_k_index: per-token indices of the selected experts
        :return: flat (num_tokens, hidden_dim) mixture output
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        # expert_mask[e, k, t] == 1 iff token t selected expert e in slot k.
        expert_mask = F.one_hot(top_k_index, num_classes=self.experts.num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)

        for expert_idx in range(self.experts.num_experts):
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

            # Run ALL tokens through the expert to gather calibration stats,
            # even if no tokens were routed to it.
            gate, up = F.linear(
                hidden_states, self.experts.gate_up_proj[expert_idx]
            ).chunk(2, dim=-1)
            expert_out_full = self.experts.act_fn(gate) * up
            expert_out_full = F.linear(
                expert_out_full, self.experts.down_proj[expert_idx]
            )
            if token_idx.numel() == 0:
                continue

            expert_weights = top_k_weights[token_idx, top_k_pos]
            weighted_output = expert_out_full[token_idx] * expert_weights.unsqueeze(
                -1
            )
            final_hidden_states.index_add_(
                0,
                token_idx,
                # Routing weights may be computed in higher precision than the
                # activations (e.g. fp32 softmax vs bf16 hidden states); cast
                # back so index_add_ dtypes match.
                weighted_output.to(hidden_states.dtype),
            )

        return final_hidden_states

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """Non-permanent wrapper: hand the untouched original module back."""
        return original
6 changes: 5 additions & 1 deletion src/llmcompressor/utils/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from loguru import logger
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS

try:
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
except ImportError: # transformers>=5 moved this
from transformers.initialization import TORCH_INIT_FUNCTIONS
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

__all__ = [
Expand Down
145 changes: 145 additions & 0 deletions tests/llmcompressor/modeling/test_calib_minimax_m2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import contextlib
from unittest import mock

import pytest
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from llmcompressor.modeling.minimax_m2 import CalibrationMiniMaxM2SparseMoeBlock
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.utils.dev import skip_weights_download
from llmcompressor.utils.helpers import calibration_forward_context
from tests.testing_utils import requires_cadence, requires_gpu

# Skip this whole test module when the installed transformers build predates
# MiniMax-M2 support (the modeling files only exist in newer releases).
MiniMaxM2Config = pytest.importorskip(
    "transformers.models.minimax_m2.configuration_minimax_m2",
    reason="MiniMaxM2Config not available in this version of transformers",
).MiniMaxM2Config
MiniMaxM2SparseMoeBlock = pytest.importorskip(
    "transformers.models.minimax_m2.modeling_minimax_m2",
    reason="MiniMaxM2SparseMoeBlock not available in this version of transformers",
).MiniMaxM2SparseMoeBlock


@requires_cadence("weekly")
@pytest.mark.parametrize("model_stub", ["hf-internal-testing/MiniMax-M2-Small"])
def test_calib_replace_minimax_m2_all_experts(model_stub):
    """Every expert's gate_up/down projection must execute during calibration."""
    with skip_weights_download():
        model = AutoModelForCausalLM.from_pretrained(model_stub)

    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))

        # Grab the first MoE block that was swapped in by the context.
        moe_layer = next(
            (
                module
                for module in model.modules()
                if isinstance(module, CalibrationMiniMaxM2SparseMoeBlock)
            ),
            None,
        )
        assert moe_layer is not None

        num_experts = moe_layer.experts.num_experts
        seen_gate = [False] * num_experts
        seen_down = [False] * num_experts
        # Map each expert weight's storage pointer to its expert index so the
        # spy below can tell which expert a given F.linear call belongs to.
        gate_ptrs = {}
        down_ptrs = {}
        for idx in range(num_experts):
            gate_ptrs[moe_layer.experts.gate_up_proj[idx].data_ptr()] = idx
            down_ptrs[moe_layer.experts.down_proj[idx].data_ptr()] = idx

        original_linear = F.linear

        def spying_linear(input, weight, *args, **kwargs):
            ptr = weight.data_ptr()
            if ptr in gate_ptrs:
                seen_gate[gate_ptrs[ptr]] = True
            if ptr in down_ptrs:
                seen_down[down_ptrs[ptr]] = True
            return original_linear(input, weight, *args, **kwargs)

        # Dummy hidden_states matching the experts' dtype and device.
        sample = torch.randn(
            2,
            8,
            model.config.hidden_size,
            dtype=moe_layer.experts.gate_up_proj.dtype,
            device=moe_layer.experts.gate_up_proj.device,
        )

        # Patch F.linear only for the duration of this single forward pass.
        with torch.no_grad(), mock.patch.object(F, "linear", spying_linear):
            moe_layer(sample)

        assert all(seen_gate), f"Not all experts were run (gate_up): {seen_gate}"
        assert all(seen_down), f"Not all experts were run (down_proj): {seen_down}"


@requires_gpu
def test_calib_minimax_m2_module():
    """The calibration wrapper must match the upstream block numerically."""
    config = MiniMaxM2Config(
        hidden_size=16,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
    )
    with torch.device("cuda"):
        original = MiniMaxM2SparseMoeBlock(config).eval()

    sample = torch.randn(2, 4, config.hidden_size, device="cuda")

    with calibration_forward_context(original):
        true_output = original(sample)

    # Both calibration modes must reproduce the upstream output.
    for calibrate_all_experts in (True, False):
        module = CalibrationMiniMaxM2SparseMoeBlock(
            original, config, calibrate_all_experts
        )
        with calibration_forward_context(module):
            output = module(sample)
        assert torch.allclose(true_output, output, atol=1e-5)


def test_calib_minimax_m2_uses_upstream_experts_when_not_calibrating_all():
    """With calibrate_all_experts=False, the upstream experts path is delegated to."""
    config = MiniMaxM2Config(
        hidden_size=16,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
    )
    original = MiniMaxM2SparseMoeBlock(config).eval()
    module = CalibrationMiniMaxM2SparseMoeBlock(original, config, False)

    sample = torch.randn(2, 4, config.hidden_size)

    with calibration_forward_context(original):
        true_output = original(sample)

    # Wrap (not replace) experts.forward so the real computation still runs
    # while we count delegations.
    with mock.patch.object(
        module.experts, "forward", wraps=module.experts.forward
    ) as spy, calibration_forward_context(module):
        output = module(sample)

    assert spy.call_count == 1
    assert torch.allclose(true_output, output, atol=1e-5)
Loading