-
Notifications
You must be signed in to change notification settings - Fork 438
[MoE] MiniMax-M2/M2.1 calibration follow-up #2335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
adaf58f
63f25ba
d205050
41865b8
ea45d08
2fed675
4b6e476
fc922b9
149d383
0ce22bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| from datasets import load_dataset | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modeling.minimax_m2 import ( # noqa: F401 | ||
| CalibrationMiniMaxM2SparseMoeBlock, | ||
| ) | ||
| from llmcompressor.modifiers.awq import AWQMapping, AWQModifier | ||
|
|
||
# Load the model and tokenizer.
model_id = "MiniMaxAI/MiniMax-M2"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# MoE calibration is handled automatically by the pipeline.
# The `CalibrationMiniMaxM2SparseMoeBlock` modules (from
# `llmcompressor.modeling.minimax_m2`) are swapped in during calibration to
# enable proper expert calibration. These replace the original
# `MiniMaxM2SparseMoeBlock` class from
# `transformers.models.minimax_m2.modeling_minimax_m2`.

# Calibration dataset configuration.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Number of calibration samples. 512 is a good place to start;
# increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load the calibration split (truncated to the sample budget) and shuffle it
# with a fixed seed so runs are reproducible.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)
|
||
def preprocess(example):
    """Render one chat-formatted sample to plain text via the chat template.

    Expects ``example["messages"]`` in the chat format accepted by
    ``tokenizer.apply_chat_template``; returns a ``{"text": ...}`` column.
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
    )
    return {"text": rendered}


ds = ds.map(preprocess)
|
|
||
|
|
||
# Tokenize inputs.
def tokenize(sample):
    """Tokenize rendered chat text, truncating to MAX_SEQUENCE_LENGTH.

    Special tokens are not re-added because the chat template already
    inserted them during preprocessing.
    """
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


# Replace the raw columns with tokenizer outputs only.
ds = ds.map(tokenize, remove_columns=ds.column_names)
|
|
||
# Layers excluded from quantization.
moe_ignores = [
    # MoE gate layers are sensitive to quantization.
    "re:.*mlp.gate$",
    # Ignore the output head.
    "lm_head",
]

# Configure the quantization algorithm to run.
recipe = AWQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=moe_ignores,
    mappings=[
        # Smooth the attention projections using the preceding layernorm.
        AWQMapping(
            "re:.*input_layernorm$",
            ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
        )
    ],
)

# Apply algorithms, calibrating one decoder layer at a time.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["MiniMaxM2DecoderLayer"],
)

# Save the compressed model and tokenizer to disk.
SAVE_DIR = f"{model_id.rstrip('/').split('/')[-1]}-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| from llmcompressor.modeling.moe_context import MoECalibrationModule | ||
|
|
||
| if TYPE_CHECKING: | ||
| from transformers import MiniMaxM2Config | ||
| from transformers.models.minimax_m2.modeling_minimax_m2 import ( | ||
| MiniMaxM2SparseMoeBlock, | ||
| ) | ||
|
|
||
|
|
||
@MoECalibrationModule.register("MiniMaxM2SparseMoeBlock")
class CalibrationMiniMaxM2SparseMoeBlock(MoECalibrationModule):
    """
    Calibration version of MiniMaxM2SparseMoeBlock that can send all tokens
    to all experts during calibration.

    When `calibrate_all_experts=True`, each expert is executed on all tokens so
    quantization statistics are collected for every expert, while routed-token
    weighting is still used for the final output.
    """

    # Temporary replacement: the original module is restored after calibration.
    is_permanent = False

    def __init__(
        self,
        original: MiniMaxM2SparseMoeBlock,
        config: MiniMaxM2Config,
        calibrate_all_experts: bool = True,
    ):
        super().__init__()
        self.config = config
        # Reuse the original submodules/weights so calibration statistics are
        # gathered on the tensors that will actually be quantized.
        self.experts = original.experts
        self.gate = original.gate
        self.calibrate_all_experts = calibrate_all_experts
        self.jitter_noise = original.jitter_noise
        # Registered as a buffer so it follows the module across devices/dtypes.
        self.register_buffer(
            "e_score_correction_bias", original.e_score_correction_bias
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with optional all-expert calibration mode.

        - `calibrate_all_experts=False`: use upstream expert execution path.
        - `calibrate_all_experts=True`: execute every expert on all tokens,
          then aggregate only routed-token outputs.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Mirrors the upstream block: optional multiplicative jitter while
        # training (inactive during eval-mode calibration).
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        # Flatten to (num_tokens, hidden_dim) for per-token routing.
        hidden_states = hidden_states.view(-1, hidden_dim)
        _router_logits, top_k_weights, top_k_index = self.gate(
            hidden_states, self.e_score_correction_bias
        )

        if not self.calibrate_all_experts:
            # Delegate to the upstream experts module unchanged.
            final_hidden_states = self.experts(
                hidden_states, top_k_index, top_k_weights
            )
            return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

        # Reimplementing MiniMaxM2Experts.forward only when calibrating all experts.
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # (tokens, top_k, experts) -> (experts, top_k, tokens) so that
        # expert_mask[e] marks which tokens routed to expert e.
        expert_mask = F.one_hot(top_k_index, num_classes=self.experts.num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)

        for expert_idx in range(self.experts.num_experts):
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            has_tokens = token_idx.numel() > 0

            # Run all tokens through the expert to gather stats. This happens
            # before the routed-token check so observer hooks see every token.
            gate, up = F.linear(
                hidden_states, self.experts.gate_up_proj[expert_idx]
            ).chunk(2, dim=-1)
            expert_out_full = self.experts.act_fn(gate) * up
            expert_out_full = F.linear(
                expert_out_full, self.experts.down_proj[expert_idx]
            )
            # Only routed tokens contribute to the output.
            if not has_tokens:
                continue
            expert_out = expert_out_full[token_idx]

            expert_weights = top_k_weights[token_idx, top_k_pos]
            weighted_output = expert_out * expert_weights.unsqueeze(-1)
            # NOTE(review): the cast guards against dtype promotion when
            # top_k_weights has a wider dtype than hidden_states (e.g. fp32
            # router weights with bf16 activations); it is a no-op otherwise.
            final_hidden_states.index_add_(
                0,
                token_idx,
                weighted_output.to(hidden_states.dtype),
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )

        return final_hidden_states

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """Return the original module unchanged; shared weights need no copy-back."""
        return original
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,145 @@ | ||
| import contextlib | ||
| from unittest import mock | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from transformers import AutoModelForCausalLM | ||
|
|
||
| from llmcompressor.modeling.minimax_m2 import CalibrationMiniMaxM2SparseMoeBlock | ||
| from llmcompressor.modeling.moe_context import moe_calibration_context | ||
| from llmcompressor.utils.dev import skip_weights_download | ||
| from llmcompressor.utils.helpers import calibration_forward_context | ||
| from tests.testing_utils import requires_cadence, requires_gpu | ||
|
|
||
# Skip this whole test module when the installed transformers version does
# not yet ship the MiniMax-M2 architecture.
MiniMaxM2Config = pytest.importorskip(
    "transformers.models.minimax_m2.configuration_minimax_m2",
    reason="MiniMaxM2Config not available in this version of transformers",
).MiniMaxM2Config
MiniMaxM2SparseMoeBlock = pytest.importorskip(
    "transformers.models.minimax_m2.modeling_minimax_m2",
    reason="MiniMaxM2SparseMoeBlock not available in this version of transformers",
).MiniMaxM2SparseMoeBlock
|
|
||
|
|
||
@requires_cadence("weekly")
@pytest.mark.parametrize("model_stub", ["hf-internal-testing/MiniMax-M2-Small"])
def test_calib_replace_minimax_m2_all_experts(model_stub):
    """Verify the calibration MoE block runs every expert on every token.

    Intercepts ``F.linear`` and records which expert weight tensors (by
    ``data_ptr``) were used, then asserts all gate_up and down projections
    were exercised by a single forward pass.
    """
    with skip_weights_download():
        model = AutoModelForCausalLM.from_pretrained(model_stub)

    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))

        # Find one MoE layer that was replaced by the calibration context.
        moe_layer = None
        for _, module in model.named_modules():
            if isinstance(module, CalibrationMiniMaxM2SparseMoeBlock):
                moe_layer = module
                break

        assert moe_layer is not None

        num_experts = moe_layer.experts.num_experts
        seen_gate = [False for _ in range(num_experts)]
        seen_down = [False for _ in range(num_experts)]
        # Map each expert's weight storage pointer back to its expert index.
        gate_ptrs = {
            moe_layer.experts.gate_up_proj[i].data_ptr(): i for i in range(num_experts)
        }
        down_ptrs = {
            moe_layer.experts.down_proj[i].data_ptr(): i for i in range(num_experts)
        }

        original_linear = F.linear

        def patched_linear(input, weight, *args, **kwargs):
            # Record which expert weights flow through F.linear, then defer
            # to the real implementation so outputs are unchanged.
            ptr = weight.data_ptr()
            if ptr in gate_ptrs:
                seen_gate[gate_ptrs[ptr]] = True
            if ptr in down_ptrs:
                seen_down[down_ptrs[ptr]] = True
            return original_linear(input, weight, *args, **kwargs)

        # Create dummy input tensor that simulates hidden_states.
        hidden_dim = model.config.hidden_size
        batch, seq_len = 2, 8
        sample = torch.randn(
            batch,
            seq_len,
            hidden_dim,
            dtype=moe_layer.experts.gate_up_proj.dtype,
            device=moe_layer.experts.gate_up_proj.device,
        )

        # mock.patch.object guarantees F.linear is restored even if the
        # forward raises, unlike a manual assign/try/finally, and matches
        # the mocking style used elsewhere in this module.
        with torch.no_grad(), mock.patch.object(F, "linear", patched_linear):
            _ = moe_layer(sample)

        assert all(seen_gate), f"Not all experts were run (gate_up): {seen_gate}"
        assert all(seen_down), f"Not all experts were run (down_proj): {seen_down}"
|
|
||
|
|
||
@requires_gpu
def test_calib_minimax_m2_module():
    """Check numeric equivalence between the original MoE block and the
    calibration wrapper in both modes: all-expert execution changes which
    experts *run*, not the routed-token output.
    """
    # Tiny config so the block constructs and runs quickly on one GPU.
    config = MiniMaxM2Config(
        hidden_size=16,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
    )
    with torch.device("cuda"):
        original = MiniMaxM2SparseMoeBlock(config).eval()

    hidden_dim = config.hidden_size
    sample = torch.randn(2, 4, hidden_dim, device="cuda")

    # Reference output from the unmodified block.
    with calibration_forward_context(original):
        true_output = original(sample)

    # calibrate_all_experts=True: output must match despite running all experts.
    module = CalibrationMiniMaxM2SparseMoeBlock(original, config, True)
    with calibration_forward_context(module):
        output = module(sample)
    assert torch.allclose(true_output, output, atol=1e-5)

    # calibrate_all_experts=False: delegates to the upstream expert path.
    module = CalibrationMiniMaxM2SparseMoeBlock(original, config, False)
    with calibration_forward_context(module):
        output = module(sample)
    assert torch.allclose(true_output, output, atol=1e-5)
|
|
||
|
|
||
def test_calib_minimax_m2_uses_upstream_experts_when_not_calibrating_all():
    """With calibrate_all_experts=False the wrapper must delegate to the
    upstream experts module exactly once and reproduce its output.
    """
    # Tiny CPU-only config; this test needs no GPU.
    config = MiniMaxM2Config(
        hidden_size=16,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=4,
        max_position_embeddings=64,
        num_experts_per_tok=2,
        num_local_experts=4,
    )
    original = MiniMaxM2SparseMoeBlock(config).eval()
    module = CalibrationMiniMaxM2SparseMoeBlock(original, config, False)

    sample = torch.randn(2, 4, config.hidden_size)

    # Reference output from the unmodified block.
    with calibration_forward_context(original):
        true_output = original(sample)

    # Wrap (not replace) experts.forward so the delegation is observable
    # while real outputs are still produced.
    with mock.patch.object(
        module.experts, "forward", wraps=module.experts.forward
    ) as mocked_forward:
        with calibration_forward_context(module):
            output = module(sample)

    assert mocked_forward.call_count == 1
    assert torch.allclose(true_output, output, atol=1e-5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `weighted_output` tensor should already have the same dtype as `hidden_states`, as it is derived from operations on tensors that originate from `hidden_states` and `top_k_weights`. Therefore, the `.to(hidden_states.dtype)` call appears to be redundant and can be removed for a minor cleanup.