
Hybrid KV cache support for mamba+attention models#467

Draft
malaiwah wants to merge 4 commits into llm-d:main from malaiwah:fix/hybrid-kv-offload-rebased

Conversation

@malaiwah

Summary

Extends the llm-d fs-backend to support hybrid models like Qwen3.5 that interleave mamba and attention layers with multiple KV cache groups.

Rebased cleanly on current main (replaces #466 which had conflicts).

Changes

spec.py:

  • Per-group gpu_blocks_per_file and FileMapper instances (one subdirectory per group)
  • Supports hybrid_chunk_size from vLLM's HybridOffloadPlanner
  • Multi-group assertion replaces single-group assert len(gpu_block_size) == 1
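
The per-group layout can be sketched as follows. This is an illustrative helper, not the actual spec.py API: each KV cache group gets its own subdirectory (one FileMapper per group), and the old single-group assertion becomes a check that at least one group exists.

```python
from pathlib import PurePosixPath

def group_subdirs(base: str, num_groups: int) -> list[PurePosixPath]:
    """Illustrative sketch: one subdirectory per KV cache group,
    mirroring the per-group FileMapper instances described above."""
    # multi-group assertion replaces the old single-group assert
    assert num_groups >= 1, "need at least one KV cache group"
    return [PurePosixPath(base) / f"group_{i}" for i in range(num_groups)]
```

For a Qwen3.5-style model with 4 groups (3 mamba + 1 attention), this yields `group_0` through `group_3` under the storage root.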

worker.py:

  • StorageOffloadingHandlers creates per-group engines and handlers
  • _get_tensors extended for hybrid backends: handles multi-tensor state, fallback_block_size, NotImplementedError from non-standard backends
  • Canonical tensor normalization: reshapes attention tensors from the backend's kernel block size (which can vary by backend and GPU) to the vLLM page size (deterministic), ensuring identical file sizes across restarts and GPU hardware. Zero-copy reshape — no data movement.
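
The normalization idea can be sketched in NumPy (the real code operates on torch tensors; the shapes below assume the kernel_block_size=32 / kb_per_canonical=33 figures reported later in this PR):

```python
import numpy as np

kernel_block_size = 32      # assumed: kernel block size from this PR's validation run
kb_per_canonical = 33       # assumed: page_size // kernel_block_size = 1056 // 32

# (num_kernel_blocks, kernel_block_size, head_dim) — illustrative shapes
t = np.zeros((66, kernel_block_size, 8), dtype=np.float16)

# group 33 kernel blocks into one canonical page-aligned block;
# reshaping a contiguous array returns a view: zero-copy, no data movement
canonical = t.reshape(66 // kb_per_canonical,
                      kb_per_canonical * kernel_block_size, 8)

assert canonical.shape == (2, 1056, 8)
assert np.shares_memory(t, canonical)   # same buffer, no copy
```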

Test plan

  • Qwen3.5-4B-FP8 (4 groups: 3 mamba + 1 attention), all groups store/load to NFS
  • 99% cache hit on cold restart with canonical format
  • Deterministic file sizes across 3 consecutive restarts (no mismatches)
  • Syntax check passes

Closes #465
Related: LMCache/LMCache#2879, vllm-project/vllm#38261

AI-assisted: developed with Claude. All changes reviewed and tested by a human.

🤖 Generated with Claude Code

@github-actions

Unsigned commits detected! Please sign your commits.

For instructions on how to set up GPG/SSH signing and verify your commits, please see GitHub Documentation.

@malaiwah malaiwah force-pushed the fix/hybrid-kv-offload-rebased branch 2 times, most recently from 200c6c5 to 3b180c5 on March 27, 2026 00:51
@malaiwah
Author

Added a second commit with the C++ canonical block grouping that makes the on-disk format portable across GPUs:

C++ changes (tensor_copier.hpp/cu, storage_offload.*, storage_offload_bindings.cpp):

  • New kernel_blocks_per_canonical_block parameter in TensorCopier
  • When page_size ≠ kernel_block_size, the copy loop groups page_size // kernel_block_size kernel blocks per canonical block, making the file layout page-aligned regardless of GPU
  • Staging buffer scaled accordingly; Python binding exposes the parameter
  • worker.py passes kb_per_gb = page_size // kernel_block_size to the engine constructor
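
A minimal Python model of that grouping loop (the real implementation is C++/CUDA in TensorCopier; names here are illustrative):

```python
import numpy as np

def copy_to_canonical(kernel_blocks: np.ndarray, kb_per_cb: int) -> np.ndarray:
    """Illustrative sketch: group `kb_per_cb` kernel blocks into one
    canonical block, as the copy loop does when page_size != kernel_block_size."""
    n_kb, kbs = kernel_blocks.shape
    assert n_kb % kb_per_cb == 0, "kernel blocks must fill whole canonical blocks"
    # staging buffer scaled by the canonical factor
    staging = np.empty((n_kb // kb_per_cb, kb_per_cb * kbs),
                       dtype=kernel_blocks.dtype)
    for cb in range(staging.shape[0]):
        for j in range(kb_per_cb):
            # place kernel block j at its page-aligned slot in canonical block cb
            staging[cb, j * kbs:(j + 1) * kbs] = kernel_blocks[cb * kb_per_cb + j]
    return staging
```

With the page_size=1056 / kernel_block_size=32 figures reported below, kb_per_cb is 33, so the file layout is page-aligned regardless of which GPU wrote it.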

Cross-GPU validation (Creativity RTX 4080 Super → AIBoss RTX 5090, both Qwen3.5-4B-FP8):

  • Same kb_per_canonical=33 on both (page_size=1056, kernel_block_size=32)
  • AIBoss cold-reads NFS files written by Creativity: 54.6% hit rate
  • AIBoss replay: 97.6% avg cache hit — identical to Creativity
  • Output coherent (no data corruption from layout mismatch)

Both commits are GPG-signed.

…tability

Extends the llm-d fs-backend to handle hybrid models (e.g. Qwen3.5-4B-FP8
with mamba + attention layers) that have multiple KV cache groups with
different block sizes and tensor layouts.

Python changes (spec.py, worker.py):
- Per-group gpu_blocks_per_file and FileMapper instances
- Supports hybrid_chunk_size from vLLM's HybridOffloadPlanner
- StorageOffloadingHandlers creates per-group engines and handlers
- _get_tensors extended for hybrid backends (multi-tensor state,
  fallback block size)
- Multi-group dispatch via _store_handlers/_load_handlers lists

C++ changes (tensor_copier.hpp/cu, storage_offload.hpp/cpp, bindings):
- New kernel_blocks_per_canonical_block parameter in TensorCopier
- When page_size != kernel_block_size, the copy loop groups multiple
  kernel blocks into one canonical page-aligned block, making the
  on-disk format deterministic across GPUs with different kernel block
  sizes (e.g. RTX 4080 Super kb=32 vs RTX 5090 kb=64)
- worker.py passes kb_per_gb = page_size // kernel_block_size as
  kernel_blocks_per_canonical_block to the engine constructor
- staging buffer scaled by canonical factor

Cross-GPU validation (Creativity RTX 4080 Super -> AIBoss RTX 5090):
- Both use page_size=1056 / kernel_block_size=32 -> kb_per_canonical=33
- AIBoss cold-reads NFS files written by Creativity: 54.6% hit rate
- AIBoss replay: 97.6% avg cache hit (identical to Creativity)
- Output coherent — no data corruption from layout mismatch

Closes llm-d#465

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Co-authored-by: Claude <noreply@anthropic.com>
@malaiwah malaiwah force-pushed the fix/hybrid-kv-offload-rebased branch from cd4c991 to f21291b on March 27, 2026 01:55
malaiwah and others added 2 commits March 26, 2026 23:14
After adding block_offsets_list and block_counts_list parameters to
copy_blocks_via_kernels in tensor_copier.hpp, the implementation in
tensor_copier_kernels.cu was not updated, causing a C++ compile error.

The kernel path is only invoked when block_offsets_list is nullptr
(copy_blocks() forces use_kernel=false for partial-range transfers),
so the new parameters are accepted but unused in the kernel
implementation.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Mamba/SSM recurrent state (mamba_ssm_dtype=float32) is numerically
incompatible across CUDA SM architectures.  SM 8.9 and SM 12.0 use
different CUDA kernel implementations for the SSM mixer that produce
different float32 values for identical inputs due to non-associative
parallel reductions.  Loading SM-8.9 state on SM-12.0 (or vice versa)
causes the model to output garbage (repeating tokens, word salad).

For hybrid models (mamba/linear-attention + full-attention), embed the
GPU SM version in the NFS storage path so each architecture maintains
its own isolated namespace:

  group_{i}/sm_89/...   (SM 8.9, e.g. RTX 4080 Super)
  group_{i}/sm_100/...  (SM 10.0, e.g. B200)
  group_{i}/sm_120/...  (SM 12.0, e.g. RTX 5090)

Same-architecture restarts still get full cache hits.  Cross-architecture
loads simply miss and recompute cleanly.

Non-hybrid models are unaffected (gpu_tag=None, paths unchanged).

Validated: cross-node test (RTX 4080 Super → RTX 5090) all 3 rounds pass
with 0/0 garbage responses after fix.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@malaiwah
Author

Added a third commit (26d0b60) that fixes cross-GPU mamba state corruption for hybrid models running on shared NFS storage.

Root cause: Mamba/SSM recurrent state (mamba_ssm_dtype=float32) is numerically incompatible across CUDA SM architectures. SM 8.9 and SM 12.0 use different CUDA kernel implementations for the SSM mixer that produce different float32 values for identical inputs due to non-associative parallel reductions. Loading SM-8.9 state on SM-12.0 causes the model to output garbage (repeating tokens / multilingual word salad).

Fix (spec.py): For hybrid models, embed the GPU SM version in the NFS storage path so each architecture maintains an isolated namespace:

group_{i}/sm_89/...     (SM 8.9, RTX 4080 Super)
group_{i}/sm_120/...    (SM 12.0, RTX 5090)

Same-architecture restarts still get full cache hits. Cross-architecture loads simply miss and recompute cleanly. Non-hybrid models are unaffected.
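
The tag derivation can be sketched like this (hypothetical helper, not the actual spec.py code; the tag format — `sm_{major}{minor}` with no separator, `sm_unknown` fallback — matches the behavior the unit tests verify):

```python
def gpu_sm_tag(get_capability=None) -> str:
    """Illustrative sketch: derive the per-architecture namespace tag.
    Defaults to torch.cuda.get_device_capability when available."""
    try:
        if get_capability is None:
            import torch  # deferred import so the sketch runs without torch
            get_capability = torch.cuda.get_device_capability
        major, minor = get_capability()
        return f"sm_{major}{minor}"   # e.g. (8, 9) -> "sm_89" (no separator)
    except Exception:
        return "sm_unknown"           # CUDA unavailable -> safe fallback, no crash

# hypothetical path layout: the tag sits between the group dir and the model
def group_path(base: str, group: int, tag: str, model: str) -> str:
    return f"{base}/group_{group}/{tag}/{model}"
```

Because the tag only changes the path prefix, same-architecture restarts hit the same files, while a different architecture simply looks in an empty namespace and recomputes.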

Cross-node validation (Creativity RTX 4080 Super SM 8.9 → AIBoss RTX 5090 SM 12.0, Qwen3.5-4B-FP8):

  • Round 1: Creativity warm write to NFS → ✅
  • Round 2: AIBoss cold load from NFS (first access, cache miss, recomputes) → ✅ (was: '2026 2026 2026 22025...' garbage)
  • Round 3: AIBoss APC warm (hits local cache) → ✅ (was: multilingual word salad from poisoned APC)

Pre-fix failure mode confirmed and eliminated. Both containers log their SM tag at startup:

[INFO] llmd_fs_backend: Hybrid model: inserting GPU tag 'sm_89' into storage paths ...
[INFO] llmd_fs_backend: Hybrid model: inserting GPU tag 'sm_120' into storage paths ...

15 tests covering the cross-GPU KV cache isolation feature added to
SharedStorageOffloadingSpec:

TestHybridModelGPUTagPath (8 tests):
- sm_tag_format: parametrized over SM 8.9, 10.0, 12.0 — verifies
  sm_89/sm_100/sm_120 tag format (no separator)
- hybrid_path_structure: full path form {storage}/group_0/{sm_tag}/{model}/…
- all_groups_tagged: all 4 KV cache groups get the same GPU tag (critical
  for mamba group_3)
- different_sm_versions_produce_different_paths: SM 8.9 ≠ SM 12.0 paths
- sm_tag_fallback_on_cuda_error: CUDA unavailable → sm_unknown, no crash
- sm_unknown_does_not_match_real_tag: fallback path ≠ any real SM path

TestNonHybridModelNoGPUTag (4 tests):
- non-hybrid paths have no sm_ component
- non-hybrid paths are identical across GPU architectures (shared cache)

TestSpecFileMapperConsistency (3 tests):
- file_mappers count matches group count
- file_mapper (singular) is file_mappers[0]
- each group uses its own subdirectory

Tests use unittest.mock to patch OffloadingSpec.__init__ and
torch.cuda.get_device_capability — no CUDA device or running engine
required.  All 15 tests complete in ~3s.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@malaiwah
Author

Two more commits added:

eab49a6 — test(spec): unit tests for GPU SM version path isolation

15 tests in tests/test_spec_gpu_path.py covering the cross-GPU isolation fix. Three test classes:

  • TestHybridModelGPUTagPath (8): tag format parametrized over SM 8.9/10.0/12.0, path structure, all groups tagged, cross-SM path isolation, CUDA error fallback to sm_unknown
  • TestNonHybridModelNoGPUTag (4): non-hybrid paths are GPU-architecture-agnostic
  • TestSpecFileMapperConsistency (3): FileMapper count, singular/plural alias, per-group subdirectories

Tests mock OffloadingSpec.__init__ and torch.cuda.get_device_capability — no GPU device or running engine required. All 15 tests complete in ~3s.
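
The no-GPU testing pattern can be illustrated with a stand-in object (not the actual test file, which patches the real `torch.cuda.get_device_capability`):

```python
from unittest import mock
import types

# illustrative stand-in for torch.cuda so the pattern runs without a GPU
fake_cuda = types.SimpleNamespace(get_device_capability=lambda: (8, 9))

def sm_tag(cuda=fake_cuda) -> str:
    # looks up the capability at call time, so patching takes effect
    try:
        major, minor = cuda.get_device_capability()
        return f"sm_{major}{minor}"
    except Exception:
        return "sm_unknown"

# patch the capability call to emulate a different architecture
with mock.patch.object(fake_cuda, "get_device_capability",
                       return_value=(12, 0)):
    assert sm_tag() == "sm_120"

assert sm_tag() == "sm_89"  # patch is reverted outside the context manager
```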

Validated end-to-end:

  • Creativity (SM 8.9): [INFO] llmd_fs_backend: Hybrid model: inserting GPU tag 'sm_89' into storage paths
  • AIBoss (SM 12.0): [INFO] llmd_fs_backend: Hybrid model: inserting GPU tag 'sm_120' into storage paths
  • Cross-node test all 3 rounds ✅ — 0 garbage responses

@malaiwah
Author

Full cross-node + cross-restart test suite: 10/10 PASS

Two-host setup: Creativity (RTX 4080 Super, SM 8.9) ↔ AIBoss (RTX 5090, SM 12.0), model lovedheart/Qwen3.5-4B-FP8 (24 mamba + 8 attention layers), both running the same multiarch container image.

| Phase | Description | Result |
|-------|-------------|--------|
| A | Creativity cold start: 9325-token prompt, NFS miss, correct output | PASS |
| B | Creativity APC warm: 8448 cached tokens reused within same session | PASS |
| C | NFS namespace: Creativity wrote 0 new sm_120 files | PASS |
| D | AIBoss cold: sm_120 ≠ sm_89 path → NFS miss → no garbage output | PASS |
| E | AIBoss APC warm: not poisoned by garbled cold-miss response | PASS |
| F | Namespace isolation: sm_89/ and sm_120/ coexist under same group dir | PASS |
| G | Creativity restart: 7392 tokens loaded from sm_89 NFS warm | PASS |
| H | AIBoss restart: 7392 tokens loaded from sm_120 NFS warm | PASS |
| I | LMCache hybrid detection fires in both Worker and EngineCore processes | PASS |
| J | Concurrent cross-host requests: 97% word overlap, no interference | PASS |

Phase G and H confirm that the SM-namespaced NFS format is stable across container restarts: each GPU architecture reads back exactly its own state with no cross-contamination and no file-size mismatches.

Startup log confirming the SM tags are active on both hosts:

[Creativity]  [INFO] llmd_fs_backend: Hybrid model: inserting GPU tag 'sm_89' into storage paths
[AIBoss]      [INFO] llmd_fs_backend: Hybrid model: inserting GPU tag 'sm_120' into storage paths

@malaiwah
Author

Gentle ping on this — all CI checks are green (auto-assign ✅, signed-commits ✅) and the cross-node + cross-restart validation came back 10/10 PASS (see comment above). Happy to address any review feedback. @vMaroon @dannyharnik @kfirtoledo — would appreciate a look when you have time.

@malaiwah
Author

Friendly ping — all CI checks are green (auto-assign ✅, signed-commits ✅) and the PR is mergeable. The only remaining gate is a code review.

Summary of what's here across the 5 commits:

  1. Multi-group hybrid KV cache support (spec.py, worker.py) — enables Qwen3.5-class models
  2. C++ canonical block grouping (TensorCopier) — GPU-portable on-disk format
  3. SM-versioned NFS namespaces (spec.py) — prevents mamba state corruption across GPU architectures
  4. Unit tests (15 tests, no GPU required, ~3s)
  5. End-to-end validation: 10/10 cross-node + cross-restart test suite PASS (RTX 4080 Super SM 8.9 ↔ RTX 5090 SM 12.0)

@vMaroon @dannyharnik @kfirtoledo — would appreciate a review when you get a chance.

@kfirtoledo
Collaborator

kfirtoledo commented Mar 29, 2026

Thanks for the PR @malaiwah! Great to see interest in HMA support.

Just wanted to share some context on the approach of the FS connector. The HMA support for offloading is primarily being implemented in the vLLM offloading connector itself, not in the FS backend. This allows a clean architectural separation between what belongs in the connector (features that should be shared across multiple backends) and what belongs in backends like the fs backend (fast store/load to shared storage). As a result, the HMA changes in the fs connector are minimal: mostly interface adoption.

The relevant vllm PRs are currently under review here:
https://github.com/vllm-project/vllm/pulls?q=is%3Apr+is%3Aopen+label%3Akv-connector++%22kv_offload%2Bhma+%22

Once those are merged, we have a PR ready with the minimal fs connector changes: #476 (tracking issue: #472)

Given your experience with HMA models, I would really appreciate your review and feedback on the open vllm PRs and the llm-d fs connector adoption.

@malaiwah
Author

malaiwah commented Apr 9, 2026

Interesting, I'll try to have a look.

@malaiwah malaiwah marked this pull request as draft April 9, 2026 12:27