
[Perf] DSV3.2 Indexer Fused Weights Projection #38684

Open
benchislett wants to merge 1 commit into vllm-project:main from CentML:perf/deepseek-fused-wk

Conversation

@benchislett
Collaborator

Purpose

Fuse the WK and Weights_Proj projections in the DSV3.2 Indexer. This is an alternative optimization to #35968, which overlaps the projections instead of fusing them. Doing the fusion provides a greater speedup:

Benchmark timings for DSV3.2 NVFP4 on 8xB200 (TP8, No Specdec)

BS128 8k/1k

Fused WK+WeightsProj (Decode):  21.6 ms
Baseline             (Decode):  22.3 ms

BS1 8k/1k

Fused WK+WeightsProj (Decode):  9.5 ms
Overlap              (Decode):  9.8  ms
Baseline             (Decode):  10.2  ms

Fused WK+WeightsProj (TTFT):  339 ms
Overlap              (TTFT):  340 ms
Baseline             (TTFT):  341 ms
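
For intuition about where the win comes from, here is a minimal sketch of the fusion in plain PyTorch (shapes are illustrative and the names are stand-ins, not vLLM's actual implementation): the two weight matrices are concatenated along the output dimension so a single, wider GEMM replaces the two separate launches, and the result is split afterwards.

```python
import torch

# Minimal sketch of the fusion idea (illustrative shapes, not the real model's):
# one GEMM over the concatenated weights replaces separate wk and weights_proj GEMMs.
hidden_size, head_dim, n_head, num_tokens = 512, 128, 64, 4

w_k = torch.randn(head_dim, hidden_size)   # stand-in for the wk weight
w_wp = torch.randn(n_head, hidden_size)    # stand-in for the weights_proj weight
w_fused = torch.cat([w_k, w_wp], dim=0)    # [head_dim + n_head, hidden_size]

x = torch.randn(num_tokens, hidden_size)
kw = x @ w_fused.t()                       # single fused projection
k, weights_raw = kw.split([head_dim, n_head], dim=-1)

# The fused outputs match the two separate projections.
assert torch.allclose(k, x @ w_k.t(), atol=1e-4)
assert torch.allclose(weights_raw, x @ w_wp.t(), atol=1e-4)
```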

Testing

GSM8k shows a slight decrease.

PR (2 runs):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9500|±  | 0.006|
|     |       |strict-match    |     5|exact_match|↑  |0.9492|±  | 0.006|

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9553|±  |0.0057|
|     |       |strict-match    |     5|exact_match|↑  |0.9530|±  |0.0058|

Main (2 runs):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9538|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9545|±  |0.0057|

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9462|±  |0.0062|
|     |       |strict-match    |     5|exact_match|↑  |0.9462|±  |0.0062|

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett requested a review from luccafong as a code owner on April 1, 2026 at 03:59
The mergify bot added the "deepseek (Related to DeepSeek models)" label on Apr 1, 2026
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request fuses the `wk` and `weights_proj` layers into a single `MergedColumnParallelLinear` named `wk_weights_proj` in the DeepSeek-V2 and MTP models to improve performance. Several critical issues were identified: forcing `quant_config=None` may lead to loading errors with quantized checkpoints; the tensor slicing in the forward pass assumes a 2D input and could fail with 3D tensors (a flattening suggestion was provided); and the weight-loading logic is susceptible to name corruption and potential crashes due to substring matching and missing guard conditions.

Comment on lines +646 to 653
```python
self.wk_weights_proj = MergedColumnParallelLinear(
    hidden_size,
    [self.head_dim, self.n_head],
    bias=False,
    quant_config=None,
    disable_tp=True,
    prefix=f"{prefix}.wk_weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
```

Severity: high

Forcing quant_config=None for wk_weights_proj will cause correctness issues when loading from quantized checkpoints (e.g., FP8). Since wk is typically quantized in DeepSeek-V3/V3.2 checkpoints, the weight_loader will attempt to bit-copy quantized weights into a non-quantized parameter without applying the necessary scales. The current weight_loader for MergedColumnParallelLinear does not handle on-the-fly dequantization. If fusion is required for performance, you must implement a custom weight loader that can dequantize wk during the loading process or ensure that the quantization configuration is correctly propagated.
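
For illustration, a custom loader would need something along these lines (a rough sketch only; the function name, scale layout, and shard bookkeeping are assumptions, not vLLM's actual loader API):

```python
import torch

def load_fused_shard(fused_param: torch.Tensor,
                     loaded_weight: torch.Tensor,
                     weight_scale: torch.Tensor | None,
                     shard_offset: int,
                     shard_size: int) -> None:
    """Sketch: dequantize a per-tensor FP8 checkpoint weight before copying it
    into its shard of an unquantized fused parameter (assumed layout)."""
    if loaded_weight.dtype == torch.float8_e4m3fn and weight_scale is not None:
        # Apply the per-tensor scale so the unquantized parameter gets real values.
        loaded_weight = loaded_weight.to(torch.float32) * weight_scale
    fused_param.data[shard_offset:shard_offset + shard_size].copy_(
        loaded_weight.to(fused_param.dtype))
```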

Comment on lines +696 to +698
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]

Severity: high

The slicing kw[:, : self.head_dim] assumes that kw is a 2D tensor. However, when using torch.compile or in certain prefill paths, hidden_states (and thus kw) can be a 3D tensor. In such cases, this slicing will produce incorrect results. It is safer to flatten hidden_states to a token-based representation before the projection, which also ensures consistency for the subsequent indexer_op call on line 733.

Suggested change

```python
# current
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]

# suggested
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
```

Comment on lines +1443 to +1447
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)

Severity: high

The substring replacement logic in load_weights is susceptible to a bug where wk matches wk_weights_proj. If a checkpoint already contains fused weights (e.g., from a previous save), name.replace("wk", "wk_weights_proj") will corrupt the parameter name (e.g., resulting in ...wk_weights_proj_weights_proj). Additionally, this mapping is added unconditionally, which may cause crashes in non-V3.2 models where wk_weights_proj is not defined. Please add wk_weights_proj to the guard condition in the weight loading loop (around line 1503) to ensure it only attempts to map if the parameter exists in params_dict.
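
A small standalone illustration of the hazard and a guard in the spirit of the suggestion (all names here are hypothetical):

```python
# If the checkpoint already stores the fused tensor, the naive replace corrupts the name.
params_dict = {"model.layers.0.indexer.wk_weights_proj.weight": "param"}

name = "model.layers.0.indexer.wk_weights_proj.weight"
mapped = name.replace("wk", "wk_weights_proj")
print(mapped)  # model.layers.0.indexer.wk_weights_proj_weights_proj.weight (no such parameter)

# Guard: only take the stacked mapping when the rewritten name actually exists
# among the model's parameters; otherwise fall back to the default loading path.
if mapped in params_dict:
    target_name = mapped
else:
    target_name = name
```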

Comment on lines +245 to +246
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),

Severity: high

Similar to the issue in deepseek_v2.py, the wk substring match can corrupt weight names if the checkpoint is already fused. Ensure that wk_weights_proj is included in the guard condition within the load_weights loop (around line 292) to prevent incorrect mapping and potential KeyError when the parameter is missing in certain model configurations.
