Skip to content

[New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash with essential fixes#41834

Open
jasl wants to merge 109 commits into
vllm-project:mainfrom
jasl:codex/ds4-sm120-min-enable
Open

[New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash with essential fixes#41834
jasl wants to merge 109 commits into
vllm-project:mainfrom
jasl:codex/ds4-sm120-min-enable

Conversation

@jasl

@jasl jasl commented May 6, 2026

Copy link
Copy Markdown
Contributor

Summary

This PR enables DeepSeek V4 Flash on SM120/SM121 Blackwell client hardware by carrying the SM12x fallback and tuning stack needed for the current vLLM V1 path. It is intended for RTX PRO 6000 Blackwell Workstation Edition, RTX 5090-class SM120, and GB10 / DGX Spark SM121 users who cannot use SM100-only TMEM / tcgen05 kernels.

Duplicate-work check

Open PR search was refreshed on 2026-06-12 for SM120 / SM12x / DeepSeek V4 / GB10 terms. The nearest open PRs are related but not duplicates:

PR Difference
#43477 Draft branch for DeepSeek V4 + GLM-5.1 using FlashInfer SM120 sparse MLA and DeepGEMM SM120 / MXFP4 dependency branches. This PR keeps the SM12x fallback/tuning path and validation surface for users who need the current vLLM branch without that external branch stack.
#40929 Earlier WIP Triton fallback effort. This PR is the maintained replacement branch with the broader scheduler, prefix-cache, parser, quant, warmup, and harness-validated fixes carried forward.
#42856 Focused workspace-bound fix that explicitly depends on / references this PR; it is a subset-style bugfix, not the full DeepSeek V4 SM12x enablement branch.

Fixed preview tags

These tags are in jasl/vllm and give users stable pins while the PR is still moving:

Tag Commit Use
sm120-pr-41834-stable-preview-20260612075245 f32247a5a695fa8979d61837bf6b87da897dcb7d validated rebased PR branch preview
sm120-pr-41834-fallback-before-replacement-20260612053720 5d1584e2de2b3c64540e70dfc370b0211eb6b2fc fallback tag for the old PR head before branch replacement
sm120-pr-41834-before-warmup-push-20260617 5be22eb0e23a7513b73fbdfd2b3c18b033a385eb prior head before the 2026-06-17 rebase + wedge fix (regressed instruction-following, see jasl#19)
sm120-pr-41834-stable-preview-20260617 73e99c16599061eea4561406c70fff3d2130a4ee rebased-onto-upstream head with the long-prefill wedge fix

Update 2026-06-17 — rebased onto upstream + long-prefill wedge fix

Rebased the branch onto current upstream (b2cfae777, 105 commits) and added a fix for the first-long-prefill wedge surfaced during PR testing: the D512-split sparse-MLA prefill Triton kernels JIT-compiled inside the engine step on the first long request (~20s), parking EngineCore in shm_broadcast -> "sample_tokens RPC timed out". The fix pre-compiles them during the DeepSeek-V4 sparse-MLA warmup (new env VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP, default on); synthetic warm only, no inference-path change, clean no-op when unreachable.

Validated on 2x RTX PRO 6000 Blackwell (SM120) and 2-node GB10 / DGX Spark (SM121), DeepSeek-V4-Flash, fp8 KV, MTP=2.

Wedge fix and correctness

First long-prefill in-inference JIT of the D512-split prefill kernels (the wedge), warmup off vs on:

Platform warmup off warmup on (fix)
RTX SM120 12 0
GB10 SM121 -- 0

GSM8K-500 (5-shot, greedy temp 0):

Config flexible strict
RTX SM120, decode gate ON 0.946 0.918
RTX SM120, decode gate OFF 0.930 0.904
GB10 SM121 (prior session) 0.950 0.925

Instruction-following (jasl#19 repro, exact prompt, temp 0): clean JSON-only output across samples on the default path -- the prior-head regression is absent.

Gated SM120 decode optimization -- requires FlashInfer git-main

The VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE=1 decode optimization requires FlashInfer git-main (>=0.6.13), which provides flashinfer.mla._sparse_mla_sm120 (absent from the 0.6.12 release). Install/update it before enabling the gate:

pip install --upgrade "flashinfer-python @ git+https://github.com/flashinfer-ai/flashinfer.git"
# if flashinfer-cubin lags the wheel: export FLASHINFER_DISABLE_VERSION_CHECK=1

RTX SM120, decode gate ON vs OFF, ctx0 decode (aggregate tok/s, 0 errors all rows):

C gate OFF gate ON gain
1 189.7 201.4 +6%
2 311.3 334.7 +8%
4 483.1 531.6 +10%
8 707.6 801.9 +13%
16 990.5 1164.7 +18%
32 1545.0 1849.6 +20%
64 2132.5 2814.8 +32%

gate-ON @C64 = 2814.8 tok/s matches the community target (~2815). RTX gate-ON 59K long-context decode (aggregate tok/s, 0 errors): 214.5 / 230.6 / 437.8 / 657.3 / 1056.6 / 965.6 / 2243.7 at C = 1/2/4/8/16/32/64.

GB10 SM121, decode gate ON, MTP=2 vs nomtp (vllm bench serve, 1024-in / 1024-out):

C nomtp tok/s mtp2 tok/s MTP gain nomtp median TPOT ms mtp2 median TPOT ms
1 26.8 37.9 +41% 36.98 25.68
4 67.5 79.5 +18% 57.28 47.85
8 94.2 109.9 +17% 82.02 69.68
16 127.6 147.4 +15% 120.35 104.93

(MTP both raises throughput and lowers per-output-token latency; the gap narrows as concurrency rises.)

The default path needs no FlashInfer update. With the gate off (default), the import is lazy/gated, so FlashInfer 0.6.12 (official) works unchanged. The long-prefill wedge fix is fully orthogonal (imports no FlashInfer). On GB10 / 2-node, also pin nvidia-nccl-cu13==2.30.7 (a rebuild reverts it; a per-node mismatch hangs the NCCL handshake).

Branch validation, 2026-06-12

Base and head:

  • upstream base: 8a91228dbe363d1d113deb2a82e289429130dd01
  • PR head: f32247a5a695fa8979d61837bf6b87da897dcb7d
  • branch range: 96 commits over upstream/main

Commands run on the final head:

Command Result
git diff --check upstream/main...HEAD pass
DCO scan over upstream/main..HEAD pass; every commit has Signed-off-by
VLLM_TARGET_DEVICE=empty .venv/bin/python -m compileall -q vllm/envs.py vllm/model_executor/warmup/kernel_warmup.py vllm/models/deepseek_v4 vllm/v1/core vllm/v1/attention/backends/mla vllm/reasoning/deepseek_v4_reasoning_parser.py tests/test_envs.py tests/v1/core/test_prefix_caching.py tests/v1/core/test_scheduler.py tests/reasoning/test_deepseekv4_reasoning_parser.py tests/quantization/test_sm12x_tuned_config_lookup.py pass
.venv/bin/python -m pytest tests/test_envs.py::test_deepseek_v4_sparse_mla_stats_path_env -q on the remote vLLM environment 1 passed, 16 warnings
python3 -m pytest tests/test_scripts.py -q in the public harness 128 passed in 14.41s

Local vLLM pytest/ruff were not run on the Mac checkout because its .venv does not currently include torch or ruff. GPU-path validation remains remote SM120/SM121-only.

Latest clean SM120 RTX PRO 6000 x2 data, 2026-06-12

Artifact roots:

  • artifacts/codex_pr_stable_preview_f32247a/2x_rtx_pro_6000_sm120/rtx_current_pr_short_throughput_mtp_noep_20260612084721
  • artifacts/codex_pr_stable_preview_f32247a/2x_rtx_pro_6000_sm120/rtx_current_pr_clean_mtp_noep_20260612080629

Short-throughput profile:

  • TP=2, MTP=2, expert parallel off, FP8 KV, block size 256.
  • max_model_len=131072, gpu_memory_utilization=0.975, max_num_batched_tokens=4096, max_num_seqs=24.
  • Prefix cache disabled, FULL_AND_PIECEWISE, 80 prompts per concurrency.
  • Phase exits: server_startup=0, bench_hf_mt_bench=0, bench_random_prefill_sweep=0.
  • Regression check: output/input throughput ratios are against the previous accepted same-profile EP-off reference; all are above the 0.95 floor.

HF MT-bench, 80 prompts:

C output tok/s ratio vs reference mean TTFT ms p99 ITL ms MTP acceptance %
1 180.94 1.009 49.59 13.08 68.36
2 284.53 1.003 70.04 32.35 68.19
4 427.10 0.999 82.70 38.83 68.25
8 600.33 1.005 110.97 86.19 67.91
16 840.46 1.019 156.73 86.50 67.34
24 987.77 1.030 209.05 86.71 68.20

Random prefill sweep, C=1, output length 128, 8 requests per case:

Prompt / output tokens input tok/s ratio vs reference mean TTFT ms requests
4K / 128 3123.74 0.996 660.21 8 / 8
16K / 128 6209.00 1.005 2030.49 8 / 8
64K / 128 7049.72 0.999 8715.51 8 / 8

Correctness and reliability profile:

  • TP=2, MTP=2, expert parallel off, FP8 KV, prefix cache disabled, max_model_len=131072, max_num_seqs=4, max_num_batched_tokens=4096.
  • Phase exits: server_startup=0, bench_hf_mt_bench=0, eval_gsm8k=0, bench_random_prefill_sweep=0, bench_random_8000x1000=0, bench_random_256x256=0.
  • Post-run current-boot driver scan found no Xid, UVM, NV_ERR, GPU-lost, illegal-access, unspecified-launch, or fatal GPU signals; no vLLM compute processes were left running.

GSM8K 5-shot, limit-200, /v1/completions, MTP=2, concurrency 4:

Metric Value Floor Result
flexible exact match 0.965 0.940 pass
strict exact match 0.940 0.925 pass

Additional 128K-profile random checks:

Shape C output tok/s mean TTFT ms p99 ITL ms MTP acceptance %
8K / 1K 1 130.93 1367.03 13.44 52.56
8K / 1K 2 191.19 1586.64 17.44 50.28
8K / 1K 4 260.72 1666.96 199.75 51.76
256 / 256 1 153.07 88.80 13.17 51.46
256 / 256 4 369.86 127.80 84.44 52.50

Latest clean GB10 / SM121 data, 2026-06-12

Artifact root:

  • artifacts/codex_pr_stable_preview_f32247a/2x_gb10_sm121/gb10_forum53_mtp2_epoff_c2_gmem0685_mml81920/20260612074113

Profile:

  • TP=2, MTP=2, expert parallel off, FP8 KV, block size 256.
  • max_model_len=81920, max_num_seqs=2, max_num_batched_tokens=4096, gpu_memory_utilization=0.685.
  • Prefix cache enabled; Forum Refactor attention kernels #53 C=2 shape: forum53_c2:2:2:3200:256.
  • This covers the 80K-token prompt case on the final PR head. Failed, interrupted, or driver-signal artifacts are intentionally excluded from this PR body.

Gate result:

Gate Result
summary ok true
serve_start.exit_code 0
streaming_pressure.exit_code 0
driver health ok=true, signal count 0
request failures 0 / 4
preemptions 0

Timing and runtime summary:

Metric Value
max prompt tokens 80,127
max TTFT 124.045698 s
max elapsed 124.949141 s
avg inter-chunk latency 0.056711 s
p95 inter-chunk latency 0.064278 s
p99 inter-chunk latency 0.144954 s
max inter-chunk latency 0.144954 s
GPU KV usage avg / max 65.81% / 86.40%
prefix-cache hits / queries 79,872 / 3,444,165

AI assistance disclosure

AI assistants, including OpenAI Codex/GPT models and Anthropic Claude models, were used for code review, refactoring support, regression-script writing, and benchmark analysis. The branch was validated through human review plus the commands and harness artifacts listed above.

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added deepseek Related to DeepSeek models nvidia v1 labels May 6, 2026
@jasl

jasl commented May 6, 2026

Copy link
Copy Markdown
Contributor Author

@zyongye
I've cleaned up the old PR, could you help review this one?

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements support for DeepSeek V4 on SM12x (Blackwell) architectures by providing Triton-based fallbacks for DeepGEMM-dependent operations. Key enhancements include the introduction of specialized Triton kernels for sparse MLA, FP8 einsum, and MQA logits, as well as memory optimizations in the sparse attention indexer to compute top-k indices without materializing full logits. Additionally, the PR updates the model loader to support weight name filtering for skipping MTP weights and handles Blackwell-specific FP8 quantization scales. I have no feedback to provide.

@chatgpt-codex-connector

Copy link
Copy Markdown

💡 Codex Review

def _sparse_indexer_requires_deep_gemm() -> bool:
return current_platform.is_cuda() and not (
current_platform.is_device_capability_family(120)
)

P1 Badge Keep DeepGEMM requirement for SM120 FP4 indexer path

This helper now disables the DeepGEMM requirement for every SM120 run, but the FP4 indexer cache path still depends on DeepGEMM kernels (fp8_fp4_*) because the new SM120 fallback only handles q_scale is None (FP8 Q). With use_fp4_cache=True on SM120 and no DeepGEMM installed, construction succeeds and the first prefill/decode call fails at runtime with the DeepGEMM _missing() error instead of being rejected up front.


if self.load_config.load_format == "fastsafetensors":
weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)

P2 Badge Propagate weight_name_filter to fast safetensor loaders

The new pre-load weight_name_filter is only wired into safetensors_weights_iterator; this branch still loads all tensors for fastsafetensors (and similarly other non-default safetensor iterators), so skipped tensors are still materialized. For DeepSeek V4 this defeats the intended early skip of MTP weights and can reintroduce high transient memory use/OOM when these load formats are enabled.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

@jasl jasl changed the title [New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash [New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash with essential fixes May 6, 2026
@jasl jasl force-pushed the codex/ds4-sm120-min-enable branch from 042e366 to df2e6f8 Compare May 6, 2026 16:26
jasl and others added 30 commits June 18, 2026 06:49
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Blocks taken from the free pool are about to receive new content, so their old prefix-cache hash must be cleared even if the cache map no longer contains that exact block entry. This avoids stale block hashes accumulating under sustained prefix-cache reuse.

Based on vllm-project#44237.

Co-authored-by: Oxygen <1391083091@qq.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Upstream vllm-project#44914 carries the runtime fix. Keep a local regression test for the DeepSeek V4 MoE runner refactor path so RoutedExperts continues to use MXFP4 expert quantization.

Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
(cherry picked from commit 90bf557)
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
(cherry picked from commit 134f3a6)
DeepSeek-V4-Flash-NVFP4 (NVIDIA modelopt) sets expert_dtype=fp4 but its
MoE experts are NVFP4 (weight_scale_2 + input_scale), not MXFP4.
DeepseekV4FP8Config previously always used Mxfp4MoEMethod for fp4
experts, causing KeyError: experts.w13_input_scale on GB10 (sm_121).

Detect moe_quant_algo==NVFP4 and use the existing
ModelOptNvFp4FusedMoE for experts while keeping FP8 block for
linear/attn. Adjust weights mapper to not apply the MXFP4 .scale
rename to NVFP4 per-expert keys.

(cherry picked from commit bf6a74e)
(cherry picked from commit a34fff5da17603438d443dd0a2c35430694bff37)
Signed-off-by: jasl <jasl9187@hotmail.com>
The flashinfer_cutlass NVFP4 MoE already supports SM12x (family 100|120) and
applies the SwiGLU clamp (populates gemm1_clamp_limit, passes swiglu_limit to
flashinfer_cutlass_fused_moe), but NVFP4_BACKENDS_WITH_CLAMP listed only
FLASHINFER_TRTLLM (SM100-only). DeepSeek-V4-Flash-NVFP4 (swiglu_limit=10.0) was
therefore unservable on SM12x. Add FLASHINFER_CUTLASS to the clamp set so
--moe-backend flashinfer_cutlass serves NVFP4 on RTX SM120 / GB10 SM121 with no
FlashInfer upgrade. Verified: loads (77.95 GiB), serves, FLASHINFER_CUTLASS selected.

Signed-off-by: jasl <jasl9187@hotmail.com>
(cherry picked from commit 401d3174c3e8dc8eddf9f54ab0ea3cbb693bd5b9)
…-project#45061)

Replace the fixed PREFILL_CHUNK_SIZE chunking + batch-wide workspace bound
in the SM12x sparse-MLA prefill path with upstream's adaptive
get_prefill_chunk_plan (vllm-project#45061): pack as many requests as
fit the workspace-area bound per chunk and allocate the kv workspace
per-chunk at this chunk's compressed+gather width (chunk_M) instead of the
batch-wide worst case. Preserves the SM12x indexed-D512 split/chunked
prefill paths and the Triton sparse-MLA dispatch.

Signed-off-by: jasl <jasl9187@hotmail.com>
The rebase onto upstream/main surfaced 9 mypy errors in the sparse-MLA
decode kernels where decode_swa_lens / decode_swa_indices / seq_lens
(typed torch.Tensor | None) are indexed without a None-guard. They were
uncaught because git rebase skips pre-commit hooks. The fields are
unconditionally populated when num_decode_tokens > 0 (the only path that
reaches the decode kernels), so assert-guard them.

Signed-off-by: jasl <jasl9187@hotmail.com>
git rebase replays commits without running pre-commit hooks, so the
rebased branch carried pre-existing type-safety gaps and ruff-format
drift in our changed files (surfaced under the newer upstream config).

Fixes:
- kernel_warmup: drop the phantom _disable_sparse_mla_prefill_stats
  reference (the symbol was never defined anywhere; the stats-disable
  wrapper collapses to a direct warmup call) and type: ignore the
  intentional SimpleNamespace warmup batch for apply_grammar_bitmask
- single_type_kv_cache_manager: access MLA-spec subtype attributes via
  getattr (mypy-safe, identical runtime)
- llm_base_proposer: assert the draft temperature tensor is non-None
  (it is used unconditionally below)
- fused_moe: annotate config_file_paths: list[str]
- test_deepseek_v4_mega_moe: annotate the mixed-type calls list
- ruff-format drift across the touched files

No runtime behavior change except the warmup phantom, which previously
raised ImportError when VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH was set.

Signed-off-by: jasl <jasl9187@hotmail.com>
A dev-branch sparse-MLA stats diagnostic leaked into the PR: the only
reader was the phantom _disable_sparse_mla_prefill_stats warmup wrapper
(removed in the prior cleanup, as it referenced a never-defined symbol).
With that gone the env has no production reader, so remove its envs.py
declaration + lookup and the dedicated test that exercised it.

Signed-off-by: jasl <jasl9187@hotmail.com>
Wire FlashInfer PR3395's packed SM120 sparse-MLA decode kernel into the
DeepSeek V4 FlashMLA attention as an env-gated decode override (default
off). The kernel ships in official flashinfer >= 0.6.13; we drive it
through its low-level _SparseMLAPagedAttentionRunner rather than the
public trtllm_batch_decode_sparse_mla_dsv4 wrapper.

Root cause of the C8-C64 ctx0 decode gap versus the FlashMLA decode path
is this decode kernel (the prior PR3395 reintegration ported only the
packed prefill). Holding everything else fixed (MARLIN MoE, packed
fp8_ds_mla cache, source tree, MTP2) and swapping only the decode kernel
lifts ctx0 decode throughput on dual RTX PRO 6000 / SM120, in128/out512:

  C    default(Triton)   this    delta
  8        542            582    +7%
  16       771            833    +8%
  32       790            981   +24%
  64      1345           1683   +25%

GSM8K 5-shot limit-300 is correctness-neutral (flexible 0.953 / strict
0.927, matching the MXFP4 baseline).

Why the low-level runner and not the public wrapper: the wrapper's
_sparse_mla_decode_workspace returns no scratch when num_tokens > 64, so
it allocates mid_out/mid_lse (hundreds of MB) fresh on every decode step.
The MTP multi-query decode shape routinely exceeds 64 tokens (C32/C64),
making the wrapper a regression (-17 to -20% vs the FlashMLA path). The
runner instead takes graph-stable mid_out/mid_lse reserved once from the
vLLM workspace manager and reused every step; that cached scratch is the
entire win (a decode-shaped autotune pass over the kernel's
chunks_per_block tactic added 0% on top, so it is not included).

- New DeepseekV4FlashInferSM120Attention(DeepseekV4FlashMLAAttention)
  overrides only _forward_decode; reuses the packed cache, sparse-index
  metadata, and packed prefill. The compressed decode index is forced
  contiguous (the kernel asserts it).
- Gated by VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE, SM12x, and
  has_flashinfer_trtllm_sparse_mla_dsv4(); default off keeps the FlashMLA
  decode path byte-for-byte.

Signed-off-by: jasl <jasl9187@hotmail.com>
… startup

The first long prefill JIT-compiled the D512-split sparse-MLA prefill Triton
kernels mid-engine-step (~20s), parking EngineCore in shm_broadcast and
surfacing as a "sample_tokens RPC timed out" wedge under concurrency.

Pre-compile them during the DeepSeek-V4 sparse-MLA warmup over the complete
128-aligned combined_topk specialization set [256..1152], gated by a new env
VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP (default on). Synthetic
throwaway tensors only (no workspace manager use, no state leak); cleanly
no-ops when the split path is unreachable. No inference-path behavior change.

Signed-off-by: jasl <jasl9187@hotmail.com>
Port the dropped "Align DeepSeek V4 API semantics" layer (a874655, on
ds4-sm120-full but never on the PR line) onto the PR head 73e99c1:
- top-level `thinking` request field (DeepSeek OpenAI-compat) -> chat-template
  enable_thinking, via apply_chat_template_kwargs at the protocol boundary;
- bare DeepSeek-V4 requests (no thinking key) now default thinking ON, matching
  ds4-sm120-full (the PR line currently defaults them OFF);
- deepseek_v4_sampling_override (apply DeepSeek's official sampling defaults
  when thinking is enabled; per-request opt-out);
- reasoning_content alias on ChatMessage/DeltaMessage; prefix/wo_eos message
  fields; tool-call empty-arguments robustness in the DSv4 tokenizer.

Addresses the common bare-request instruction-following regression in
#19: preview-dev defaulted bare requests to thinking OFF, so the model
answered directly and prepended explanatory prose despite "output ONLY a JSON
array"; full defaulted bare -> thinking ON, reasoned, then complied. (The
reporter's EXPLICIT enable_thinking=false case is a separate residual softness,
not fixed by this change.)

Cherry-picked from a874655; conflicts resolved by keeping the June PR's
evolved code (build_chat_params reasoning_effort handling; the dedicated
DeepSeekV4ReasoningParser) and layering the DSv4 semantics on top
(serving._effective_chat_template_kwargs applies apply_chat_template_kwargs
after build_chat_params).

Signed-off-by: jasl <jasl9187@hotmail.com>
…tion

top_k_per_row_prefill writes its output as a contiguous [M, select_k] buffer
(it receives the logits strides, not the output's). The indexer passes
out[:, :select_k], which is non-contiguous whenever the compressed-KV count is
below the top-k width -- i.e. for short prompts and the early queries of long
prompts. Writing it as contiguous silently corrupts the later rows' top-k
(all -1), so the C4A sparse-MLA prefill drops the distant/downsampled context
and attends only the recent sliding window. This degrades instruction
following (returns prose instead of the requested JSON) and garbles
long-context generation under concurrent traffic.

Hand the op a contiguous work buffer and copy the result back; this is a no-op
when the slice is already contiguous (select_k == top-k width), so behavior is
unchanged outside the corrupted case. The chunked long-context path was already
safe (stride-aware torch.topk / torch.gather).

Signed-off-by: jasl <jasl9187@hotmail.com>
The non-contiguous-output fix used selected.contiguous(), which copies the
slice's current contents (the -1 placeholders just written by out.fill_(-1))
into the work buffer. top_k_per_row_prefill then overwrites every element, so
that copy is wasted. Allocate an uninitialized contiguous buffer via
selected.new_empty(selected.shape) instead; behavior is unchanged (copy-back
still lands the result in the strided slice), one elementwise pass saved on
the short-prompt / early-query path.

Signed-off-by: jasl <jasl9187@hotmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: No status
Status: No status

Development

Successfully merging this pull request may close these issues.