[Perf] DSV3.2 Indexer Fused Weights Projection #38684
benchislett wants to merge 1 commit into vllm-project:main
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Code Review
This pull request fuses the `wk` and `weights_proj` layers into a single `MergedColumnParallelLinear` named `wk_weights_proj` within the DeepSeek-V2 and MTP models to optimize performance. Several critical issues were identified: forcing `quant_config=None` may lead to loading errors with quantized checkpoints; the tensor slicing in the forward pass assumes a 2D input, which could fail with 3D tensors (a flattening suggestion was provided); and the weight-loading logic is susceptible to name corruption and potential crashes due to substring matching and missing guard conditions.
```diff
-        self.wk = ReplicatedLinear(
-            hidden_size,
-            self.head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.wk",
-        )
-        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
-        self.weights_proj = ReplicatedLinear(
-            hidden_size,
-            self.n_head,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.weights_proj",
-        )
+        self.wk_weights_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [self.head_dim, self.n_head],
+            bias=False,
+            quant_config=None,
+            disable_tp=True,
+            prefix=f"{prefix}.wk_weights_proj",
+        )
+        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
```
Forcing quant_config=None for wk_weights_proj will cause correctness issues when loading from quantized checkpoints (e.g., FP8). Since wk is typically quantized in DeepSeek-V3/V3.2 checkpoints, the weight_loader will attempt to bit-copy quantized weights into a non-quantized parameter without applying the necessary scales. The current weight_loader for MergedColumnParallelLinear does not handle on-the-fly dequantization. If fusion is required for performance, you must implement a custom weight loader that can dequantize wk during the loading process or ensure that the quantization configuration is correctly propagated.
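To illustrate the failure mode, here is a hedged sketch (a hypothetical helper, not vLLM's actual `weight_loader` API): an FP8 checkpoint stores low-precision values alongside a scale, so a bit-copy into an unquantized fused parameter drops the scale entirely. Plain float tensors stand in for the float8 payload.

```python
import torch

# Hypothetical sketch, not vLLM's loader: an unquantized fused parameter
# must receive value * scale, not the raw quantized bits.
def dequantize_for_fused_load(loaded_weight: torch.Tensor,
                              weight_scale: torch.Tensor) -> torch.Tensor:
    # True weight = quantized value * per-tensor scale.
    return loaded_weight.to(torch.float32) * weight_scale

quantized = torch.tensor([2.0, 4.0, 6.0])  # stand-in for the FP8 payload
scale = torch.tensor(0.5)

# A bit-copy would hand the layer [2, 4, 6]; the real weights are:
restored = dequantize_for_fused_load(quantized, scale)
print(restored.tolist())  # [1.0, 2.0, 3.0]
```

Real FP8 checkpoints typically use blockwise rather than per-tensor scales, but the point is the same: the loader must apply the scale before writing into the fused, unquantized parameter.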
```python
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
```
The slicing kw[:, : self.head_dim] assumes that kw is a 2D tensor. However, when using torch.compile or in certain prefill paths, hidden_states (and thus kw) can be a 3D tensor. In such cases, this slicing will produce incorrect results. It is safer to flatten hidden_states to a token-based representation before the projection, which also ensures consistency for the subsequent indexer_op call on line 733.
Suggested change:
```diff
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         kw, _ = self.wk_weights_proj(hidden_states)
         k = kw[:, : self.head_dim]
         weights_raw = kw[:, self.head_dim :]
```
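A minimal repro of the shape hazard (toy dimensions, invented for illustration): on a 3D tensor, `kw[:, :head_dim]` slices the sequence axis rather than the feature axis, while flattening to a token-based view first makes dim 1 the feature axis.

```python
import torch

head_dim, n_head = 4, 2
# 3D activation: (batch, seq_len, head_dim + n_head)
kw = torch.randn(3, 5, head_dim + n_head)

# On a 3D tensor, this slices the *sequence* axis -- the bug described above.
wrong = kw[:, :head_dim]
print(wrong.shape)  # torch.Size([3, 4, 6]): truncated sequence, full features

# Flattening to (tokens, features) first makes the slice correct.
flat = kw.view(-1, kw.shape[-1])
k = flat[:, :head_dim]
weights_raw = flat[:, head_dim:]
print(k.shape, weights_raw.shape)  # torch.Size([15, 4]) torch.Size([15, 2])
```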
```python
indexer_fused_mapping = [
    ("wk_weights_proj", "wk", 0),
    ("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
```
The substring replacement logic in load_weights is susceptible to a bug where wk matches wk_weights_proj. If a checkpoint already contains fused weights (e.g., from a previous save), name.replace("wk", "wk_weights_proj") will corrupt the parameter name (e.g., resulting in ...wk_weights_proj_weights_proj). Additionally, this mapping is added unconditionally, which may cause crashes in non-V3.2 models where wk_weights_proj is not defined. Please add wk_weights_proj to the guard condition in the weight loading loop (around line 1503) to ensure it only attempts to map if the parameter exists in params_dict.
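A hedged sketch of the corruption and a guard (`map_checkpoint_name` is a hypothetical helper, not the PR's `load_weights` code): checking for the fused name first short-circuits any further substring replacement, and matching on dot-delimited tokens avoids accidental partial matches.

```python
# Hypothetical helper, not the PR's code: demonstrates the substring
# pitfall and a guard against re-mapping already-fused names.
def map_checkpoint_name(name: str) -> str:
    # Guard: an already-fused checkpoint name must pass through unchanged.
    if "wk_weights_proj" in name:
        return name
    for src in ("wk", "weights_proj"):
        token = f".{src}."
        if token in name:  # match dot-delimited tokens, not bare substrings
            return name.replace(token, ".wk_weights_proj.")
    return name

# The naive replacement corrupts a fused name:
fused = "layers.0.indexer.wk_weights_proj.weight"
print(fused.replace("wk", "wk_weights_proj"))
# layers.0.indexer.wk_weights_proj_weights_proj.weight

# The guarded mapping does not:
print(map_checkpoint_name(fused))                         # unchanged
print(map_checkpoint_name("layers.0.indexer.wk.weight"))  # mapped to fused name
```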
```python
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
```
Similar to the issue in deepseek_v2.py, the wk substring match can corrupt weight names if the checkpoint is already fused. Ensure that wk_weights_proj is included in the guard condition within the load_weights loop (around line 292) to prevent incorrect mapping and potential KeyError when the parameter is missing in certain model configurations.
Purpose
Fuse the WK and Weights_Proj projections in the DSV3.2 Indexer. This is an alternative optimization to #35968, which overlaps the projections instead of fusing them. Doing the fusion provides a greater speedup:
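The equivalence behind the fusion can be sketched numerically (toy dimensions, plain matmuls standing in for the linear layers): concatenating the two weight matrices turns two GEMM launches into one, and slicing the fused output recovers both results exactly.

```python
import torch

torch.manual_seed(0)
hidden_size, head_dim, n_head, tokens = 8, 4, 2, 3
x = torch.randn(tokens, hidden_size)
wk = torch.randn(head_dim, hidden_size)    # stand-in for the wk weight
wproj = torch.randn(n_head, hidden_size)   # stand-in for weights_proj

# Two separate projections (two GEMM launches)...
k_sep = x @ wk.t()
w_sep = x @ wproj.t()

# ...equal one fused projection over the concatenated weight, then a split.
fused_weight = torch.cat([wk, wproj], dim=0)  # (head_dim + n_head, hidden)
kw = x @ fused_weight.t()
k_fused = kw[:, :head_dim]
w_fused = kw[:, head_dim:]

print(torch.allclose(k_sep, k_fused), torch.allclose(w_sep, w_fused))
```

The saving comes from launching a single, larger GEMM instead of two small ones, which matters most when the individual projections are too small to saturate the GPU.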
Benchmark timings for DSV3.2 NVFP4 on 8xB200 (TP8, No Specdec)
BS128 8k/1k
BS1 8k/1k
Testing
GSM8K accuracy shows a slight decrease.
PR (2 runs):
Main (2 runs):