-
Notifications
You must be signed in to change notification settings - Fork 432
feat: add Qwen3.5 MoE calibration module #2383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Sehyo
wants to merge
12
commits into
vllm-project:main
Choose a base branch
from
Sehyo:feat/qwen3-5-moe-calibration
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+916
−3
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
2ab3874
feat: add Qwen3.5 MoE calibration module for quantization
Sehyo e04cc8c
Update src/llmcompressor/modeling/qwen3_5_moe.py
Sehyo be275c3
test: add calibration test for Qwen3.5 MoE module
Sehyo d6c7bca
fix: persist regex ignore patterns in saved config.json
Sehyo 122f10c
fix: copy missing processor configs in Qwen3.5 MoE example
Sehyo 09474ab
fix: add compat shims for newer transformers versions
Sehyo 202a2bd
feat: preserve MTP weights from source checkpoint during save
Sehyo c8ede1c
fix: overhaul Qwen3.5 MoE example for multimodal + production use
Sehyo 28414db
Merge branch 'main' into feat/qwen3-5-moe-calibration
HDCharles dcfdce2
fix: address PR review feedback from HDCharles
Sehyo 0fcf3ed
Merge branch 'main' into feat/qwen3-5-moe-calibration
dsikka 6ec9128
Merge branch 'main' into feat/qwen3-5-moe-calibration
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| import os | ||
| import shutil | ||
|
|
||
| from datasets import concatenate_datasets, load_dataset | ||
| from huggingface_hub import snapshot_download | ||
| from transformers import AutoModelForImageTextToText, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
| from llmcompressor.utils import dispatch_for_generation | ||
|
|
||
| # Available Qwen3.5 MoE models (pick one): | ||
| # "Qwen/Qwen3.5-35B-A3B" | ||
| # "Qwen/Qwen3.5-122B-A10B" | ||
| # "Qwen/Qwen3.5-397B-A17B" | ||
| MODEL_ID = "Qwen/Qwen3.5-35B-A3B" | ||
|
|
||
| # Load model. | ||
| model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype="auto") | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
|
||
| # Select number of samples. 512 is recommended for production quality; | ||
| # reduce to 256 or lower for faster iteration during development. | ||
| NUM_CALIBRATION_SAMPLES = 256 | ||
| MAX_SEQUENCE_LENGTH = 4096 | ||
|
|
||
| # Load datasets and preprocess. | ||
| # Use half from each source for a diverse calibration set. | ||
| samples_per_dataset = NUM_CALIBRATION_SAMPLES // 2 | ||
|
|
||
| ds_ultrachat = load_dataset( | ||
| "HuggingFaceH4/ultrachat_200k", | ||
| split=f"train_sft[:{samples_per_dataset}]", | ||
| ) | ||
| ds_nemotron = load_dataset( | ||
| "nvidia/Nemotron-Post-Training-Dataset-v2", | ||
| split=f"chat[:{samples_per_dataset}]", | ||
| ) | ||
|
|
||
| # Both datasets share a "messages" column with the same chat format. | ||
| # Keep only that column so we can concatenate them. | ||
| ds_ultrachat = ds_ultrachat.select_columns(["messages"]) | ||
| ds_nemotron = ds_nemotron.select_columns(["messages"]) | ||
| ds = concatenate_datasets([ds_ultrachat, ds_nemotron]) | ||
| ds = ds.shuffle(seed=42) | ||
|
|
||
|
|
||
| def preprocess(example): | ||
| return { | ||
| "text": tokenizer.apply_chat_template( | ||
| example["messages"], | ||
| tokenize=False, | ||
| ) | ||
| } | ||
|
|
||
|
|
||
| ds = ds.map(preprocess) | ||
|
|
||
|
|
||
| # Tokenize inputs. | ||
| def tokenize(sample): | ||
| return tokenizer( | ||
| sample["text"], | ||
| padding=False, | ||
| max_length=MAX_SEQUENCE_LENGTH, | ||
| truncation=True, | ||
| add_special_tokens=False, | ||
| ) | ||
|
|
||
|
|
||
| ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
|
||
| # Configure the quantization algorithm and scheme. | ||
| # In this case, we: | ||
| # * quantize the weights to fp4 with per group 16 via ptq | ||
| # * calibrate a global_scale for activations, which will be used to | ||
| # quantize activations to fp4 on the fly | ||
| recipe = QuantizationModifier( | ||
| targets="Linear", | ||
| scheme="NVFP4", | ||
| ignore=[ | ||
| "lm_head", | ||
| "re:.*mlp.gate$", | ||
| "re:.*mlp.shared_expert_gate$", | ||
| "re:.*linear_attn.*", | ||
| "re:model\\.visual\\..*", | ||
| ], | ||
| ) | ||
|
|
||
| # Apply quantization. | ||
| # MoE calibration is now handled automatically by the pipeline. | ||
| # We set `moe_calibrate_all_experts` to True to ensure all experts receive | ||
| # calibration data. This temporarily updates the model definition to use | ||
| # `CalibrationQwen3_5MoeSparseMoeBlock` (from `llmcompressor.modeling.qwen3_5_moe`) | ||
| # which replaces the original `Qwen3_5MoeSparseMoeBlock` class. | ||
| # This unfuses the 3D expert parameters into individual nn.Linear modules | ||
| # so they can be targeted by quantization. | ||
| # Feel free to update the definition under | ||
| # llm-compressor/src/llmcompressor/modeling/qwen3_5_moe.py to play around with | ||
| # this behavior and evaluate its impact on quantization performance. | ||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| moe_calibrate_all_experts=True, | ||
| ) | ||
|
|
||
|
|
||
| print("\n\n") | ||
| print("========== SAMPLE GENERATION ==============") | ||
| try: | ||
| dispatch_for_generation(model) | ||
| input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to( | ||
| model.device | ||
| ) | ||
| output = model.generate(input_ids, max_new_tokens=100) | ||
| print(tokenizer.decode(output[0])) | ||
| except Exception as e: | ||
| print(f"Generation failed (non-fatal): {e}") | ||
| print("==========================================\n\n") | ||
|
|
||
|
|
||
| # Save to disk in compressed-tensors format. | ||
| # MTP (multi-token prediction) weights are automatically preserved from | ||
| # the source checkpoint during save, enabling speculative decoding with vLLM. | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
|
|
||
| # Hot-fix: copy processor configs that save_pretrained doesn't bring over | ||
| cache_dir = snapshot_download(MODEL_ID, allow_patterns=["*.json"]) | ||
| for filename in [ | ||
| "preprocessor_config.json", | ||
| "video_preprocessor_config.json", | ||
| ]: | ||
| src = os.path.join(cache_dir, filename) | ||
| if os.path.exists(src): | ||
| shutil.copyfile(src, os.path.join(SAVE_DIR, filename)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| from llmcompressor.modeling.moe_context import MoECalibrationModule | ||
| from llmcompressor.utils.dev import skip_weights_initialize | ||
|
|
||
| if TYPE_CHECKING: | ||
| from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( | ||
| Qwen3_5MoeSparseMoeBlock, | ||
| ) | ||
|
|
||
|
|
||
| @MoECalibrationModule.register("Qwen3_5MoeSparseMoeBlock") | ||
| class CalibrationQwen3_5MoeSparseMoeBlock(MoECalibrationModule): | ||
| """ | ||
| Calibration version of Qwen3_5MoeSparseMoeBlock that unfuses 3D expert | ||
| parameters into individual MLP modules (nn.Linear) so they can be | ||
| individually quantized. Sends all tokens to all experts during calibration. | ||
|
|
||
| is_permanent = True because the unfused structure must persist for | ||
| quantization to target the individual nn.Linear expert weights. | ||
| """ | ||
|
|
||
| is_permanent = True | ||
|
|
||
| def __init__( | ||
| self, | ||
| original: Qwen3_5MoeSparseMoeBlock, | ||
| config, | ||
| calibrate_all_experts: bool = True, | ||
| ): | ||
| super().__init__() | ||
| text_config = getattr(config, "text_config", config) | ||
|
|
||
| self.num_experts = text_config.num_experts | ||
| self.top_k = text_config.num_experts_per_tok | ||
| self.hidden_size = text_config.hidden_size | ||
|
|
||
| self.calibrate_all_experts = calibrate_all_experts | ||
| self.gate = original.gate | ||
| self.shared_expert = original.shared_expert | ||
| self.shared_expert_gate = original.shared_expert_gate | ||
| self.experts = SequentialQwen3_5MoeExperts(text_config, original.experts) | ||
|
|
||
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
| batch_size, sequence_length, hidden_dim = hidden_states.shape | ||
| hidden_states_reshaped = hidden_states.view(-1, hidden_dim) | ||
|
|
||
| # router: returns (router_logits, router_scores, router_indices) | ||
| _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) | ||
|
|
||
| # expert mask: (num_experts, top_k, num_tokens) | ||
| expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute( | ||
| 2, 1, 0 | ||
| ) | ||
|
|
||
| final_hidden_states = torch.zeros( | ||
| (batch_size * sequence_length, hidden_dim), | ||
| dtype=hidden_states.dtype, | ||
| device=hidden_states.device, | ||
| ) | ||
|
|
||
| for expert_idx, expert_layer in enumerate(self.experts): | ||
| idx, token_idx = torch.where(expert_mask[expert_idx]) | ||
|
|
||
| if self.calibrate_all_experts: | ||
| expert_out = expert_layer(hidden_states_reshaped)[token_idx] | ||
| else: | ||
| expert_out = expert_layer(hidden_states_reshaped[token_idx]) | ||
|
|
||
| if len(token_idx) > 0: | ||
| current_hidden_states = ( | ||
| expert_out * routing_weights[token_idx, idx, None] | ||
| ) | ||
| final_hidden_states.index_add_( | ||
| 0, | ||
| token_idx, | ||
| current_hidden_states.to(hidden_states.dtype), | ||
| ) | ||
|
|
||
| # shared expert | ||
| shared_expert_output = self.shared_expert(hidden_states_reshaped) | ||
| shared_expert_output = ( | ||
| F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) | ||
| * shared_expert_output | ||
| ) | ||
| final_hidden_states = final_hidden_states + shared_expert_output | ||
|
|
||
| final_hidden_states = final_hidden_states.reshape( | ||
| batch_size, sequence_length, hidden_dim | ||
| ) | ||
| return final_hidden_states | ||
|
|
||
| def restore(self, original: torch.nn.Module) -> torch.nn.Module: | ||
| return self | ||
|
|
||
|
|
||
| class SequentialQwen3_5MoeExperts(torch.nn.ModuleList): | ||
| """ | ||
| Unfuses 3D expert parameter tensors into individual Qwen3_5MoeMLP modules | ||
| so that each expert's weights are nn.Linear and can be targeted by | ||
| quantization with targets="Linear". | ||
| """ | ||
|
|
||
| def __init__(self, config, original): | ||
| from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( | ||
| Qwen3_5MoeMLP, | ||
| ) | ||
|
|
||
| self.num_experts = config.num_experts | ||
| intermediate_size = config.moe_intermediate_size | ||
|
|
||
| with skip_weights_initialize(): | ||
| super().__init__( | ||
| [ | ||
| Qwen3_5MoeMLP(config, intermediate_size=intermediate_size) | ||
| for _ in range(self.num_experts) | ||
| ] | ||
| ) | ||
|
|
||
| gate_up_data = original.gate_up_proj.data # [num_experts, 2*inter, hidden] | ||
| down_data = original.down_proj.data # [num_experts, hidden, inter] | ||
|
|
||
| for i in range(self.num_experts): | ||
| gate_up = gate_up_data[i] # [2*intermediate, hidden] | ||
| down = down_data[i] # [hidden, intermediate] | ||
|
|
||
| # gate_up_proj stores [gate; up] stacked along dim 0 | ||
| # nn.Linear weight is [out_features, in_features] | ||
| self[i].gate_proj.weight.data = ( | ||
| gate_up[:intermediate_size, :].clone().contiguous() | ||
| ) | ||
| self[i].up_proj.weight.data = ( | ||
| gate_up[intermediate_size:, :].clone().contiguous() | ||
| ) | ||
| self[i].down_proj.weight.data = down.clone().contiguous() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain why you need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not have this in mine, and mine quanted and loaded successfully in VLLM, so would love to know as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Sehyo can you explain why this is required?