import torch
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modeling.minimax_m2 import (  # noqa: F401
    CalibrationMiniMaxM2SparseMoeBlock,
)
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier

# Load the model.
model_id = "ludovicoYIN/MiniMax-M2-BF16"
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, dtype=torch.bfloat16, config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# MoE calibration is handled automatically by the pipeline.
# The `CalibrationMiniMaxM2SparseMoeBlock` modules (from
# `llmcompressor.modeling.minimax_m2`) will be applied during calibration to enable
# proper expert calibration. These replace the original
# `MiniMaxM2SparseMoeBlock` class from
# `transformers.models.minimax_m2.modeling_minimax_m2`.

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    """Render each chat transcript to plain text using the model's chat template."""
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
    )
    return {"text": rendered}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    """Tokenize rendered text; no padding, no special tokens, truncate to max length."""
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Keep the LM head and the MoE router gates in full precision.
moe_ignores = [
    "lm_head",
    "re:.*block_sparse_moe.gate$",
]

# Experts live under `model.layers.*.block_sparse_moe.experts..(w1|w2|w3)`.
EXPERT_TARGET_REGEX = [
    r"re:.*block_sparse_moe\.experts\.\d+\.w1$",
    r"re:.*block_sparse_moe\.experts\.\d+\.w2$",
    r"re:.*block_sparse_moe\.experts\.\d+\.w3$",
]


recipe = AWQModifier(
    targets=EXPERT_TARGET_REGEX,
    scheme="W4A16",
    ignore=moe_ignores,
    mappings=[
        # w1/w3 read the post-attention layernorm output; w2 reads w3's output.
        AWQMapping(
            "re:.*post_attention_layernorm$",
            ["re:.*w1$", "re:.*w3$"],
        ),
        AWQMapping("re:.*w3$", ["re:.*w2$"]),
    ],
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    processor=tokenizer,
    recipe=recipe,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    sequential_targets=["MiniMaxM2DecoderLayer"],
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)


# === src/llmcompressor/modeling/__init__.py ===
# The modeling package __init__ additionally re-exports the calibration module,
# alongside the other MoE calibration exports:
#     from .minimax_m2 import CalibrationMiniMaxM2SparseMoeBlock  # noqa: F401


# === src/llmcompressor/modeling/minimax_m2.py ===
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

from llmcompressor.modeling.moe_context import MoECalibrationModule

if TYPE_CHECKING:
    from transformers import MiniMaxM2Config
    from transformers.models.minimax_m2.modeling_minimax_m2 import (
        MiniMaxM2SparseMoeBlock,
    )


@MoECalibrationModule.register("MiniMaxM2SparseMoeBlock")
class CalibrationMiniMaxM2SparseMoeBlock(MoECalibrationModule):
    """Calibration module for MiniMaxM2SparseMoeBlock with all-expert calibration."""

    # Temporary replacement: the stock block is restored after calibration.
    is_permanent = False

    def __init__(
        self,
        original: "MiniMaxM2SparseMoeBlock",
        config: "MiniMaxM2Config",
        calibrate_all_experts: bool = True,
    ):
        """Wrap an existing sparse-MoE block for calibration.

        :param original: the transformers MiniMax M2 sparse-MoE block being wrapped
        :param config: model config supplying ``num_local_experts``
        :param calibrate_all_experts: when True, run every expert on all tokens so
            calibration observers see full activation statistics
        """
        super().__init__()

        # Gating flag for the all-expert calibration path.
        self.calibrate_all_experts = calibrate_all_experts

        # Extract submodules directly rather than holding `original` itself;
        # keeping the parent module would duplicate parameters in
        # find_tied_parameters.
        self.gate = original.gate
        self.experts = original.experts

        # MiniMax-specific routing parameters.
        self.jitter_noise = original.jitter_noise
        self.num_experts = config.num_local_experts
        self.top_k = original.top_k
        # Keep the *unbound* routing function so it reads this module's buffers.
        self._route_tokens_to_experts = type(original).route_tokens_to_experts
        self.register_buffer(
            "e_score_correction_bias", original.e_score_correction_bias
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with optional all-expert calibration mode.

        - `calibrate_all_experts=False`: use upstream expert execution path.
        - `calibrate_all_experts=True`: execute every expert on all tokens,
          then aggregate only routed-token outputs.
        """
        bsz, seq_len, hidden = hidden_states.shape

        # Training-time jitter, matching the upstream block (in-place multiply).
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )

        flat_states = hidden_states.view(-1, hidden)
        router_logits = self.gate(flat_states)

        # Lazily move the correction-bias buffer to the compute device.
        if self.e_score_correction_bias.device != router_logits.device:
            self.e_score_correction_bias = self.e_score_correction_bias.to(
                router_logits.device
            )
        top_k_index, top_k_weights = self._route_tokens_to_experts(self, router_logits)

        combined = torch.zeros(
            (bsz * seq_len, hidden),
            dtype=flat_states.dtype,
            device=flat_states.device,
        )
        # (tokens, top_k, experts) -> (experts, top_k, tokens) routing mask.
        routing_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
        routing_mask = routing_mask.permute(2, 1, 0)

        for idx, expert in enumerate(self.experts):
            slot_idx, token_idx = torch.where(routing_mask[idx])

            if self.calibrate_all_experts:
                # Run the expert on every token (so hooks/observers fire with
                # full statistics), then keep only the routed tokens' outputs.
                expert_out = expert(flat_states)[token_idx]
            else:
                expert_out = expert(flat_states[token_idx])

            if token_idx.numel() > 0:
                routed_weights = top_k_weights[token_idx, slot_idx]
                weighted = expert_out * routed_weights.unsqueeze(-1)
                combined.index_add_(
                    0,
                    token_idx,
                    weighted.to(flat_states.dtype),
                )

        combined = combined.reshape(bsz, seq_len, hidden)

        return combined, router_logits

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """Return the original module unchanged; no state needs copying back."""
        return original


# === tests/llmcompressor/modeling/test_calib_minimax_m2.py ===
import contextlib
import importlib
from functools import lru_cache, partial

import pytest

from transformers import AutoConfig

from llmcompressor.modeling.minimax_m2 import CalibrationMiniMaxM2SparseMoeBlock
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.utils.helpers import calibration_forward_context
from tests.testing_utils import requires_cadence


@lru_cache(maxsize=1)
def _load_minimax_remote_classes():
    """Load MiniMax M2 classes from the official HF repo via trust_remote_code."""
    config = AutoConfig.from_pretrained("MiniMaxAI/MiniMax-M2", trust_remote_code=True)
    modeling_module_name = config.__class__.__module__.replace(
        "configuration_minimax_m2", "modeling_minimax_m2"
    )
    modeling_module = importlib.import_module(modeling_module_name)
    return (
        config.__class__,
        modeling_module.MiniMaxM2SparseMoeBlock,
        modeling_module.MiniMaxM2ForCausalLM,
    )


def _make_tiny_minimax_config(config_cls):
    """Build a deliberately tiny config so the tests run quickly on CPU."""
    return config_cls(
        vocab_size=256,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
        router_jitter_noise=0.0,
    )


def _assert_outputs_close(reference, candidate, atol=1e-5):
    """Assert tensors (or parallel tuples of tensors) are elementwise close."""
    if isinstance(reference, tuple):
        assert isinstance(candidate, tuple)
        assert len(reference) == len(candidate)
        for ref_tensor, cand_tensor in zip(reference, candidate):
            assert torch.allclose(ref_tensor, cand_tensor, atol=atol)
    else:
        assert torch.allclose(reference, candidate, atol=atol)


@requires_cadence("weekly")
def test_calib_replace_minimax_m2_all_experts():
    """Every expert must run during all-expert calibration."""
    try:
        config_cls, _, model_cls = _load_minimax_remote_classes()
    except Exception as exc:
        pytest.skip(f"Unable to load MiniMax remote modeling: {exc}")

    model = model_cls(_make_tiny_minimax_config(config_cls)).eval()

    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))

        # Find the first MoE block that was swapped for the calibration module.
        moe_layer = next(
            (
                module
                for module in model.modules()
                if isinstance(module, CalibrationMiniMaxM2SparseMoeBlock)
            ),
            None,
        )
        assert moe_layer is not None

        num_experts = len(moe_layer.experts)
        expert_triggered = [False] * num_experts

        def hook_fn(i, module, input, output):
            expert_triggered[i] = True

        for i, expert in enumerate(moe_layer.experts):
            expert.register_forward_hook(partial(hook_fn, i))

        hidden_dim = model.config.hidden_size
        sample = torch.randn(2, 8, hidden_dim, dtype=torch.float32)

        with torch.no_grad():
            _ = moe_layer(sample)

        assert all(
            expert_triggered
        ), f"Not all experts were triggered: {expert_triggered}"


def test_calib_minimax_m2_module():
    """Calibration wrapper output must match the stock block in both modes."""
    try:
        config_cls, sparse_moe_block_cls, _ = _load_minimax_remote_classes()
    except Exception as exc:
        pytest.skip(f"Unable to load MiniMax remote modeling: {exc}")

    config = _make_tiny_minimax_config(config_cls)
    original = sparse_moe_block_cls(config).eval()
    sample = torch.randn(2, 4, config.hidden_size)

    with calibration_forward_context(original):
        true_output = original(sample)

    # Both calibration modes must reproduce the reference output exactly.
    for calibrate_all in (True, False):
        module = CalibrationMiniMaxM2SparseMoeBlock(
            original, config, calibrate_all_experts=calibrate_all
        ).eval()
        with calibration_forward_context(module):
            output = module(sample)
        _assert_outputs_close(true_output, output)