Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions examples/quantizing_moe/minimax_m2_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modeling.minimax_m2 import ( # noqa: F401
CalibrationMiniMaxM2SparseMoeBlock,
)
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier

# Load the model
model_id = "ludovicoYIN/MiniMax-M2-BF16"
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, dtype=torch.bfloat16, config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# MoE calibration is handled automatically by the pipeline.
# The `CalibrationMiniMaxM2SparseMoeBlock` modules (from
# `llmcompressor.modeling.minimax_m2`) will be applied during calibration to enable
# proper expert calibration. These replace the original
# `MiniMaxM2SparseMoeBlock` class from
# `transformers.models.minimax_m2.modeling_minimax_m2`.

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
# The split slice loads only the first NUM_CALIBRATION_SAMPLES rows.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
# Fixed seed keeps the calibration subset reproducible across runs.
ds = ds.shuffle(seed=42)


def preprocess(example):
    """Render one chat-format example to plain text via the chat template."""
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    """Tokenize the rendered text, truncated to MAX_SEQUENCE_LENGTH tokens."""
    return tokenizer(
        sample["text"],
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding=False,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Modules excluded from quantization: the LM head and every MoE router gate.
moe_ignores = [
    "lm_head",
    # NOTE(review): the `.` before `gate` is an unescaped regex dot — it matches
    # the intended literal `.` (and any other single character); harmless here,
    # but `\\.` would be stricter and consistent with EXPERT_TARGET_REGEX below.
    "re:.*block_sparse_moe.gate$",
]

# Experts live under `model.layers.*.block_sparse_moe.experts.<idx>.(w1|w2|w3)`.
EXPERT_TARGET_REGEX = [
    "re:.*block_sparse_moe\\.experts\\.\\d+\\.w1$",
    "re:.*block_sparse_moe\\.experts\\.\\d+\\.w2$",
    "re:.*block_sparse_moe\\.experts\\.\\d+\\.w3$",
]


# AWQ W4A16 recipe targeting only the expert projection weights.
recipe = AWQModifier(
    targets=EXPERT_TARGET_REGEX,
    scheme="W4A16",
    ignore=moe_ignores,
    mappings=[
        # Smooth w1/w3 using the post-attention layernorm output as the
        # shared input activation...
        AWQMapping(
            "re:.*post_attention_layernorm$",
            ["re:.*w1$", "re:.*w3$"],
        ),
        # ...and smooth w2 using w3's output.
        AWQMapping("re:.*w3$", ["re:.*w2$"]),
    ],
    duo_scaling=False,
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    processor=tokenizer,
    recipe=recipe,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    # Calibrate one decoder layer at a time to bound memory on this MoE model.
    sequential_targets=["MiniMaxM2DecoderLayer"],
)

# Save to disk compressed.
# e.g. "ludovicoYIN/MiniMax-M2-BF16" -> "MiniMax-M2-BF16-W4A16"
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
5 changes: 5 additions & 0 deletions src/llmcompressor/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401
from .glm4_moe import CalibrationGlm4MoeMoE # noqa: F401
from .llama4 import SequentialLlama4TextMoe # noqa: F401

try:  # Optional dependency: transformers may not include minimax_m2 yet.
    from .minimax_m2 import CalibrationMiniMaxM2SparseMoeBlock  # noqa: F401
except ImportError:  # pragma: no cover
    # ModuleNotFoundError is a subclass of ImportError, so a single
    # ImportError handler covers both a missing module and a missing name.
    pass
from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401
Expand Down
99 changes: 99 additions & 0 deletions src/llmcompressor/modeling/minimax_m2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
from llmcompressor.modeling.moe_context import MoECalibrationModule

if TYPE_CHECKING:
from transformers import MiniMaxM2Config
from transformers.models.minimax_m2.modeling_minimax_m2 import (
MiniMaxM2SparseMoeBlock,
)


@MoECalibrationModule.register("MiniMaxM2SparseMoeBlock")
class CalibrationMiniMaxM2SparseMoeBlock(MoECalibrationModule):
    """Calibration module for MiniMaxM2SparseMoeBlock that supports calibrating all experts."""

    # Temporary replacement: the original module is swapped back in after
    # calibration (see `restore`).
    is_permanent = False

    def __init__(
        self,
        original: MiniMaxM2SparseMoeBlock,
        config: MiniMaxM2Config,
        calibrate_all_experts: bool = True,
    ):
        """Wrap ``original`` for calibration.

        :param original: the upstream sparse-MoE block being replaced
        :param config: model config; only ``num_local_experts`` is read
        :param calibrate_all_experts: when True, every expert runs on all
            tokens during forward (routing still decides the final output)
        """
        super().__init__()

        # Gating
        self.calibrate_all_experts = calibrate_all_experts

        # Extract submodules directly to prevent parameter duplication
        # in find_tied_parameters (caused by holding self.original)
        self.gate = original.gate
        self.experts = original.experts

        # MiniMax specific parameters
        self.jitter_noise = original.jitter_noise
        self.num_experts = config.num_local_experts
        self.top_k = original.top_k
        # Use unbound function so this module's buffers are used.
        self._route_tokens_to_experts = type(original).route_tokens_to_experts
        self.register_buffer(
            "e_score_correction_bias", original.e_score_correction_bias
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with optional all-expert calibration mode.

        - `calibrate_all_experts=False`: use upstream expert execution path.
        - `calibrate_all_experts=True`: execute every expert on all tokens,
          then aggregate only routed-token outputs.

        NOTE(review): despite the ``-> torch.Tensor`` annotation this returns
        a ``(final_hidden_states, router_logits)`` tuple (see final return).
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Training-only jitter regularization. NOTE(review): `*=` mutates the
        # caller's tensor in place — presumably mirrors the upstream block;
        # confirm against transformers' MiniMaxM2SparseMoeBlock.
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        # Flatten (batch, seq, dim) -> (tokens, dim) for per-token routing.
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)
        # Assigning to a registered-buffer name re-registers the buffer, so it
        # is lazily moved to the logits' device on first mismatch.
        if self.e_score_correction_bias.device != router_logits.device:
            self.e_score_correction_bias = self.e_score_correction_bias.to(router_logits.device)
        # Unbound call: `self` stands in for the original block so the upstream
        # routing function reads this module's top_k and bias buffer.
        top_k_index, top_k_weights = self._route_tokens_to_experts(self, router_logits)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # (tokens, top_k) indices -> one-hot (tokens, top_k, experts), then
        # permute to (experts, top_k, tokens) so expert_mask[e] selects the
        # tokens routed to expert e.
        expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)

        for expert_idx, expert_layer in enumerate(self.experts):
            # Tokens routed to this expert, and which top-k slot they occupy.
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

            if self.calibrate_all_experts:
                # Run the expert on ALL tokens so calibration observes full
                # activation statistics; keep only the routed tokens' outputs.
                expert_out = expert_layer(hidden_states)[token_idx]
            else:
                expert_out = expert_layer(hidden_states[token_idx])

            if token_idx.numel() > 0:
                expert_weights = top_k_weights[token_idx, top_k_pos]
                weighted_output = expert_out * expert_weights.unsqueeze(-1)
                # Scatter-add the weighted expert outputs back to token slots.
                final_hidden_states.index_add_(
                    0,
                    token_idx,
                    weighted_output.to(hidden_states.dtype),
                )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )

        return final_hidden_states, router_logits

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """Return the original block so the context can swap it back in."""
        return original
6 changes: 5 additions & 1 deletion src/llmcompressor/utils/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from loguru import logger
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS

try:
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
except ImportError: # transformers>=5 moved this
from transformers.initialization import TORCH_INIT_FUNCTIONS
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

__all__ = [
Expand Down
145 changes: 145 additions & 0 deletions tests/llmcompressor/modeling/test_calib_minimax_m2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import contextlib
from unittest import mock

import pytest
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from llmcompressor.modeling.minimax_m2 import CalibrationMiniMaxM2SparseMoeBlock
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.utils.dev import skip_weights_download
from llmcompressor.utils.helpers import calibration_forward_context
from tests.testing_utils import requires_cadence, requires_gpu

# Skip this whole test module (instead of erroring at import time) when the
# installed transformers version does not yet ship the MiniMax-M2 model.
MiniMaxM2Config = pytest.importorskip(
    "transformers.models.minimax_m2.configuration_minimax_m2",
    reason="MiniMaxM2Config not available in this version of transformers",
).MiniMaxM2Config
MiniMaxM2SparseMoeBlock = pytest.importorskip(
    "transformers.models.minimax_m2.modeling_minimax_m2",
    reason="MiniMaxM2SparseMoeBlock not available in this version of transformers",
).MiniMaxM2SparseMoeBlock


@requires_cadence("weekly")
@pytest.mark.parametrize("model_stub", ["hf-internal-testing/MiniMax-M2-Small"])
def test_calib_replace_minimax_m2_all_experts(model_stub):
    """End-to-end check that, inside the MoE calibration context, every
    expert's weights are exercised by a single forward pass through the
    replaced MiniMax-M2 MoE block.
    """
    with skip_weights_download():
        model = AutoModelForCausalLM.from_pretrained(model_stub)

    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))

        # Find the first MoE block that the context replaced.
        moe_layer = None
        for _, module in model.named_modules():
            if isinstance(module, CalibrationMiniMaxM2SparseMoeBlock):
                moe_layer = module
                break

        assert moe_layer is not None

        # Map each expert's weight storage pointer to its index so the
        # F.linear interceptor below can attribute calls to experts.
        # NOTE(review): assumes experts expose `gate_up_proj`/`down_proj`
        # indexable per expert, and that the expert matmuls go through
        # `F.linear` with exactly those (sliced) weights — confirm against
        # the transformers MiniMax-M2 experts implementation.
        num_experts = moe_layer.experts.num_experts
        seen_gate = [False for _ in range(num_experts)]
        seen_down = [False for _ in range(num_experts)]
        gate_ptrs = {
            moe_layer.experts.gate_up_proj[i].data_ptr(): i for i in range(num_experts)
        }
        down_ptrs = {
            moe_layer.experts.down_proj[i].data_ptr(): i for i in range(num_experts)
        }

        original_linear = F.linear

        def patched_linear(input, weight, *args, **kwargs):
            # Record which expert's weight this linear call used, then
            # delegate to the real implementation.
            ptr = weight.data_ptr()
            if ptr in gate_ptrs:
                seen_gate[gate_ptrs[ptr]] = True
            if ptr in down_ptrs:
                seen_down[down_ptrs[ptr]] = True
            return original_linear(input, weight, *args, **kwargs)

        # Create dummy input tensor that simulates hidden_states
        hidden_dim = model.config.hidden_size
        batch, seq_len = 2, 8
        sample = torch.randn(
            batch,
            seq_len,
            hidden_dim,
            dtype=moe_layer.experts.gate_up_proj.dtype,
            device=moe_layer.experts.gate_up_proj.device,
        )

        with torch.no_grad():
            F.linear = patched_linear  # patch only within this scope
            try:
                _ = moe_layer(sample)
            finally:
                F.linear = original_linear

        assert all(seen_gate), f"Not all experts were run (gate_up): {seen_gate}"
        assert all(seen_down), f"Not all experts were run (down_proj): {seen_down}"


@requires_gpu
def test_calib_minimax_m2_module():
    """GPU parity test: the calibration wrapper must reproduce the original
    block's output in both `calibrate_all_experts` modes.

    Fix: the calibration module returns a ``(hidden_states, router_logits)``
    tuple, and ``torch.allclose`` only accepts tensors — comparing raw
    outputs raises TypeError. Normalize both sides to the hidden-state
    tensor before comparing.
    """
    config = MiniMaxM2Config(
        hidden_size=16,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
    )
    with torch.device("cuda"):
        original = MiniMaxM2SparseMoeBlock(config).eval()

    hidden_dim = config.hidden_size
    sample = torch.randn(2, 4, hidden_dim, device="cuda")

    def _hidden_states(out):
        # Accept either a bare tensor or a (hidden_states, router_logits)
        # tuple, so the test works regardless of the original block's
        # return convention.
        return out[0] if isinstance(out, tuple) else out

    with calibration_forward_context(original):
        true_output = _hidden_states(original(sample))

    # All-expert calibration mode must still produce the routed output.
    module = CalibrationMiniMaxM2SparseMoeBlock(original, config, True)
    with calibration_forward_context(module):
        output = _hidden_states(module(sample))
    assert torch.allclose(true_output, output, atol=1e-5)

    # Routed-only mode must match as well.
    module = CalibrationMiniMaxM2SparseMoeBlock(original, config, False)
    with calibration_forward_context(module):
        output = _hidden_states(module(sample))
    assert torch.allclose(true_output, output, atol=1e-5)


def test_calib_minimax_m2_uses_upstream_experts_when_not_calibrating_all():
    """With ``calibrate_all_experts=False`` the wrapper must match the
    original block, and the experts container's ``forward`` should run once.

    Fix: the calibration module returns a ``(hidden_states, router_logits)``
    tuple and ``torch.allclose`` only accepts tensors — normalize both
    outputs to the hidden-state tensor before comparing.
    """
    config = MiniMaxM2Config(
        hidden_size=16,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
    )
    original = MiniMaxM2SparseMoeBlock(config).eval()
    module = CalibrationMiniMaxM2SparseMoeBlock(original, config, False)

    sample = torch.randn(2, 4, config.hidden_size)

    def _hidden_states(out):
        # Accept either a bare tensor or a (hidden_states, router_logits)
        # tuple from either implementation.
        return out[0] if isinstance(out, tuple) else out

    with calibration_forward_context(original):
        true_output = _hidden_states(original(sample))

    with mock.patch.object(
        module.experts, "forward", wraps=module.experts.forward
    ) as mocked_forward:
        with calibration_forward_context(module):
            output = _hidden_states(module(sample))

    # NOTE(review): this expects one batched `experts.forward` call; the
    # calibration forward shown in this PR iterates experts individually —
    # confirm the experts container's forward is actually invoked.
    assert mocked_forward.call_count == 1
    assert torch.allclose(true_output, output, atol=1e-5)