
Add DeepSeekV3 support for spinquant #2465

Open
carrot-o0o wants to merge 1 commit into vllm-project:main from carrot-o0o:feat/spinquant_deepseekv3

Conversation


@carrot-o0o carrot-o0o commented Mar 11, 2026

SUMMARY:
Added support for DeepseekV3 in spinquant.

  • Currently only R1, R2, R4 and transform_type="hadamard" are implemented. If this PR seems appealing, I am willing to add support for R3 and the other transform types.
  • Temporarily patched DeepseekV3TopkRouter as nn.Linear during the transform, since compressed_tensors only supports nn.Linear.
  • Added a partial_hadamard transform that applies the rotation to only part of a weight matrix, since DeepseekV3 has its kv_proj weights merged and R4 only needs to apply to v_proj.
  • This is my first PR, so I apologize in advance if it is lacking in any way. Any kind of feedback is welcome.
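To make the partial_hadamard idea concrete, here is a minimal standalone sketch (not the PR's actual code; the layout assumption and helper names are mine): a merged KV weight stored as [num_heads * (qk_nope_head_dim + v_head_dim), hidden_size] gets an orthonormal Hadamard rotation applied only to the v rows of each head chunk, leaving the qk_nope rows untouched.

```python
import numpy as np


def sylvester_hadamard(n: int) -> np.ndarray:
    """Build an n x n Hadamard matrix via Sylvester's construction (n a power of two)."""
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h


def apply_partial_hadamard(weight, qk_nope_head_dim, v_head_dim):
    """Rotate only the v rows of each (qk_nope + v) head chunk of a merged KV weight."""
    chunk = qk_nope_head_dim + v_head_dim
    assert weight.shape[0] % chunk == 0, "row count must be a multiple of the chunk size"
    # H / sqrt(n) is orthonormal, so the rotation preserves norms.
    h = sylvester_hadamard(v_head_dim) / np.sqrt(v_head_dim)
    out = weight.copy()
    for i in range(weight.shape[0] // chunk):
        start = i * chunk + qk_nope_head_dim  # skip the qk_nope rows
        end = (i + 1) * chunk
        out[start:end, :] = h.T @ weight[start:end, :]
    return out


# Toy example: 2 heads, qk_nope_head_dim=2, v_head_dim=4, hidden_size=3.
w = np.arange(2 * (2 + 4) * 3, dtype=np.float64).reshape(12, 3)
rotated = apply_partial_hadamard(w, qk_nope_head_dim=2, v_head_dim=4)
```

In this toy run the qk rows (0-1 and 6-7) come back unchanged while the v rows are rotated norm-preservingly, which is the property R4 relies on.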

TEST PLAN:
I ran the code snippet below to confirm that applying R1, R2, R4 produces sane outputs. kanana-2-30b-a3b-instruct-2601's model_type is deepseek_v3 but it has fewer layers, making it faster to test.

from compressed_tensors.offload import dispatch_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier

# Select model and load it.
MODEL_ID = "kakaocorp/kanana-2-30b-a3b-instruct-2601"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# NOTE: currently only rotations R1, R2, and R4 are available
# R3 and learned R1/R2 rotations will be added in a future release.
# Configure the quantization algorithm to run.
#   * apply spinquant transforms to model to reduce quantization loss
#   * quantize the weights to 4 bit with group size 128
recipe = [
    SpinQuantModifier(
        rotations=["R1", "R2", "R4"],
        transform_block_size=128,
        transform_type="hadamard",
    ),
    # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(model=model, recipe=recipe, pipeline="datafree")

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR1R2R4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

Output:

========== SAMPLE GENERATION ==============
<|begin_of_text|>Hello my name is Emily and I am a 3rd year student at the University of Hertfordshire studying BSc (Hons) Forensic Science. I am currently on a placement year and I am working at the University of Hertfordshire as a laboratory demonstrator. I have a passion for science and I am keen to share my knowledge with others. I have previous experience in tutoring and I am able to adapt my teaching style to suit the needs of the student. I am patient and understanding and I am confident
==========================================

@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@carrot-o0o carrot-o0o changed the title Add DeepSeekV3 model for spinquant Add DeepSeekV3 support for spinquant Mar 11, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly extends the spinquant modifier's compatibility by integrating support for DeepSeekV3 models. It addresses the architectural nuances of DeepSeekV3, particularly its combined KV projections, by introducing a specialized partial Hadamard transformation. This ensures that spinquant can effectively apply R1, R2, and R4 rotations to DeepSeekV3, enhancing its quantization capabilities for this model family.

Highlights

  • DeepSeekV3 Support: Added support for DeepSeekV3 models within spinquant, specifically for R1, R2, and R4 rotations using the Hadamard transformation.
  • Partial Hadamard Transformation: Introduced a partial_hadamard transformation type to handle models like DeepSeekV3 where KV projection weights are merged, allowing transformations to be applied only to specific parts (e.g., V-proj).
  • Dynamic Norm Mappings: Implemented dynamic norm mapping generation for DeepSeekV3 to correctly identify and apply transformations to its unique dense and MoE layer structures.


Changelog
  • src/llmcompressor/modifiers/transform/spinquant/__init__.py
    • Imported partial_hadamard module.
  • src/llmcompressor/modifiers/transform/spinquant/base.py
    • Modified _create_r2_scheme to conditionally adjust head_dim for R2 rotation based on combined KV projections.
    • Added logic in on_start to set qk_nope_head_dim and v_head_dim for PartialHadamardFactory when KV projections are combined.
    • Updated _create_r1_scheme to dynamically include or exclude attn_v targets based on whether attn_v_is_kv_combined is true.
    • Changed _create_r2_scheme to use partial_hadamard transform type if attn_v_is_kv_combined is true and transform_type is "hadamard".
  • src/llmcompressor/modifiers/transform/spinquant/mappings.py
    • Added attn_v_is_kv_combined field to SpinQuantMapping to indicate combined K+V linear layers.
    • Defined and registered _deepseek_mapping for DeepseekV3ForCausalLM, specifying its unique attention and MLP projection patterns, including attn_v_is_kv_combined=True.
  • src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py
    • Updated NORM_MAPPING_REGISTRY to accept callable functions for generating norm mappings.
    • Implemented _build_deepseek_v3_mappings to dynamically create norm mappings for DeepSeekV3, distinguishing between dense and MoE layers.
    • Modified infer_norm_mapping_from_model to handle callable entries in the registry.
  • src/llmcompressor/modifiers/transform/spinquant/partial_hadamard.py
    • Added a new file defining PartialHadamardFactory and PartialHadamardTransform.
    • These classes enable applying Hadamard transformations only to specific "V" rows within combined KV weight matrices, using qk_nope_head_dim and v_head_dim for precise targeting.
Activity
  • The pull request was opened by carrot-o0o.
  • The author provided a detailed summary of the changes, including specific implementation details and temporary patches.
  • A comprehensive test plan with a Python code snippet and example output was included to demonstrate the functionality and verify sane outputs for R1, R2, R4 rotations on a DeepSeekV3-like model.
  • The author acknowledged this as their first PR and invited feedback.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for DeepseekV3 models within the spinquant modifier by adding a partial_hadamard transform for combined KV weights, updating mapping configurations, and dynamically building norm mappings based on model architecture. Review comments highlight several critical issues: a potential race condition due to non-thread-safe global state modification in PartialHadamardFactory, a Denial of Service vulnerability from maliciously large regexes in _build_deepseek_v3_mappings, and another Denial of Service vulnerability from possible ZeroDivisionErrors in apply_partial_transform_weight. Additionally, there are opportunities to improve code maintainability by refactoring duplicated TransformArgs creation and removing redundant assignments in PartialHadamardTransform's constructor.

Comment on lines +172 to +175
        PartialHadamardFactory.qk_nope_head_dim = (
            state.model.config.qk_nope_head_dim
        )
        PartialHadamardFactory.v_head_dim = state.model.config.v_head_dim

security-medium medium

The SpinQuantModifier.on_start method sets model-specific configuration values (qk_nope_head_dim, v_head_dim) as class-level attributes on PartialHadamardFactory. This introduces a race condition in multi-threaded environments, as these globally shared attributes can be overwritten, leading to incorrect transformations and corrupted model weights. This non-thread-safe global state modification is a significant design issue. Consider alternatives like a context manager or threading.local to avoid this.
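One way to address this, along the lines the reviewer suggests, is thread-local storage. The sketch below is hypothetical (the class body and method names are mine, not the PR's): values set on the factory in one thread are invisible to other threads, so concurrent modifiers cannot clobber each other's head dims.

```python
import threading


class PartialHadamardFactory:
    # threading.local() gives each thread its own attribute namespace,
    # unlike plain class attributes, which are shared globally.
    _local = threading.local()

    @classmethod
    def set_head_dims(cls, qk_nope_head_dim, v_head_dim):
        cls._local.qk_nope_head_dim = qk_nope_head_dim
        cls._local.v_head_dim = v_head_dim

    @classmethod
    def get_head_dims(cls):
        return (
            getattr(cls._local, "qk_nope_head_dim", None),
            getattr(cls._local, "v_head_dim", None),
        )


# Set the dims in the main thread, then observe what another thread sees.
PartialHadamardFactory.set_head_dims(128, 128)

seen_in_other_thread = {}
t = threading.Thread(
    target=lambda: seen_in_other_thread.update(
        dims=PartialHadamardFactory.get_head_dims()
    )
)
t.start()
t.join()
```

The worker thread observes `(None, None)` while the main thread keeps `(128, 128)`. A context manager or passing the dims explicitly through the factory constructor would work equally well.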

Comment on lines +49 to +54
k = getattr(model.config, "first_k_dense_replace", 1)
# regex matching dense layers: 0, 1, ..., k-1
dense_re = "|".join(str(i) for i in range(k))
# regex matching MoE layers: k, k+1, ...
# matches any number that is NOT one of the dense layer indices
moe_re = rf"(?!({'|'.join(str(i) for i in range(k))})(?:\D|$))\d+"

security-medium medium

The _build_deepseek_v3_mappings function constructs regular expressions by joining a range of integers up to k, where k is retrieved from the model configuration (model.config.first_k_dense_replace). If an attacker provides a model with a maliciously large value for this configuration parameter, it can lead to the creation of extremely large regular expressions. Compiling and matching such regexes can consume excessive CPU and memory, leading to a Denial of Service (DoS).
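A simple mitigation, sketched below with a hypothetical helper (not code from the PR), is to bound the config value against the model's actual layer count before building the alternation, so a hostile config cannot force an enormous regex.

```python
import re


def build_dense_layer_regex(first_k_dense_replace: int, num_layers: int) -> str:
    """Build the dense-layer-index regex, rejecting out-of-range config values."""
    if not (0 <= first_k_dense_replace <= num_layers):
        raise ValueError(
            f"first_k_dense_replace={first_k_dense_replace} out of range "
            f"for a model with {num_layers} layers"
        )
    # Safe now: the alternation has at most num_layers branches.
    return "|".join(str(i) for i in range(first_k_dense_replace))


# Example: 3 dense layers in a 61-layer config (illustrative numbers).
pattern = re.compile(build_dense_layer_regex(3, num_layers=61))
```

With this guard, `first_k_dense_replace=10**6` raises instead of producing a million-branch pattern.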

Comment on lines +165 to +174
    num_chunks = value.shape[0] // chunk_size
    result = value.clone()

    scale = torch.tensor(v_head_dim, dtype=torch.float64).sqrt()
    for i in range(num_chunks):
        start_idx = i * chunk_size + qk_nope_head_dim
        end_idx = (i + 1) * chunk_size
        result[start_idx:end_idx, :] = (
            _multihead_matmul(transform_weight.T, value[start_idx:end_idx, :])
            / scale

security-medium medium

The apply_partial_transform_weight function performs division by chunk_size and scale, where chunk_size is the sum of qk_nope_head_dim and v_head_dim, and scale is the square root of v_head_dim. These values are retrieved from the model configuration. If an attacker provides a model with a configuration where qk_nope_head_dim + v_head_dim == 0 or v_head_dim == 0, it will cause a ZeroDivisionError, leading to a Denial of Service.
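A minimal guard for this case, again sketched with a hypothetical helper rather than the PR's code, validates the head dims before any division or square root is computed:

```python
import math


def validate_head_dims(qk_nope_head_dim: int, v_head_dim: int):
    """Reject configs that would make chunk_size or the sqrt scale zero."""
    if v_head_dim <= 0 or qk_nope_head_dim < 0:
        raise ValueError(
            f"invalid head dims: qk_nope={qk_nope_head_dim}, v={v_head_dim}"
        )
    chunk_size = qk_nope_head_dim + v_head_dim  # guaranteed > 0
    scale = math.sqrt(v_head_dim)               # guaranteed > 0
    return chunk_size, scale


# Illustrative values (DeepSeekV3-style configs use dims of this order).
chunk_size, scale = validate_head_dims(128, 128)
```

Calling `validate_head_dims(0, 0)` raises a `ValueError` instead of letting a later division produce a `ZeroDivisionError`.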

Comment on lines +223 to +254
    apply_list = []
    apply_list.append(TransformArgs(
        targets=[
            self.mappings.embedding,
            self.mappings.attn_o,
            *self.mappings.mlp_out,
        ],
        location="weight_output",
    ))
    if getattr(self.mappings, "attn_v_is_kv_combined", False):
        apply_list.append(TransformArgs(
            targets=[
                self.mappings.attn_q,
                self.mappings.attn_k,
                *self.mappings.mlp_in,
                self.mappings.lm_head,
            ],
            location="weight_input",
            inverse=True,
        ))
    else:
        apply_list.append(TransformArgs(
            targets=[
                self.mappings.attn_q,
                self.mappings.attn_k,
                self.mappings.attn_v,
                *self.mappings.mlp_in,
                self.mappings.lm_head,
            ],
            location="weight_input",
            inverse=True,
        ))

medium

This if/else block has significant code duplication for creating the TransformArgs. This can be refactored by building the targets list first, which will make the code more concise and easier to maintain.

        apply_list = [
            TransformArgs(
                targets=[
                    self.mappings.embedding,
                    self.mappings.attn_o,
                    *self.mappings.mlp_out,
                ],
                location="weight_output",
            )
        ]

        weight_input_targets = [
            self.mappings.attn_q,
            self.mappings.attn_k,
        ]
        if not getattr(self.mappings, "attn_v_is_kv_combined", False):
            weight_input_targets.append(self.mappings.attn_v)
        weight_input_targets.extend(self.mappings.mlp_in)
        weight_input_targets.append(self.mappings.lm_head)

        apply_list.append(
            TransformArgs(
                targets=weight_input_targets,
                location="weight_input",
                inverse=True,
            )
        )

Comment on lines +98 to +105
        super().__init__(weight, perm, scheme, args, module_type)
        self.weight = weight
        self.perm = perm
        self.scheme = scheme
        self.args = args
        self.module_type = module_type
        self.qk_nope_head_dim = qk_nope_head_dim
        self.v_head_dim = v_head_dim

medium

The assignments to self.weight, self.perm, self.scheme, self.args, and self.module_type are redundant as they are already handled by the super().__init__() call. You can remove them for cleaner and more maintainable code.

Suggested change
-        super().__init__(weight, perm, scheme, args, module_type)
-        self.weight = weight
-        self.perm = perm
-        self.scheme = scheme
-        self.args = args
-        self.module_type = module_type
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.v_head_dim = v_head_dim
+        super().__init__(weight, perm, scheme, args, module_type)
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.v_head_dim = v_head_dim
