Skip to content

Commit 9081de3

Browse files
authored
Fix prefix cache restore to set KV offset explicitly (#144)
This PR is: - To make prefix-cache restore robust by explicitly restoring `KVCache.offset` from cached KV tensor length. - To avoid relying on `KVCache.state` setter side-effects for position state. - To keep RoPE position continuity correct after prefix-cache hits. - To add a focused regression test that fails if offset is not explicitly restored. ### Additional note Restoring only `state` is not sufficient if the cache implementation does not update `offset` as a side-effect. If `offset` remains `0` after restore, subsequent decode can use incorrect positions after a prefix cache hit. ### Reproduce code ```python from unittest.mock import MagicMock import mlx.core as mx import vllm_metal.v1.model_runner as mr class KVNoOffsetSideEffect: # Simulate a cache object where assigning .state does NOT update .offset. def __init__(self): self._state = [None, None] self.offset = 0 @property def state(self): return self._state @state.setter def state(self, value): self._state = value def fake_make_prompt_cache(_): # Restore will create fresh cache layers from this factory. return [KVNoOffsetSideEffect()] orig_kv, orig_make = mr.KVCache, mr.make_prompt_cache mr.KVCache, mr.make_prompt_cache = KVNoOffsetSideEffect, fake_make_prompt_cache try: # Note: token_ids length (3) is intentionally different from KV seq_len (7). # This shows offset restore comes from KV shape, not token_ids metadata. k = mx.zeros((1, 2, 7, 8), dtype=mx.float32) v = mx.zeros((1, 2, 7, 8), dtype=mx.float32) cached = mr.CachedPrefix(token_ids=[1, 2, 3], cache_state=[(k, v)]) restored = mr.PrefixCacheManager(max_bytes=1024 * 1024).restore_cache( cached, model=MagicMock(), is_vlm=False ) # Expected output after fix: restored_offset=7 print("restored_offset=", restored[0].offset) finally: mr.KVCache, mr.make_prompt_cache = orig_kv, orig_make ``` Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
1 parent 97a2844 commit 9081de3

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

tests/test_prefix_cache.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,41 @@ def fake_make_prompt_cache(model):
9393
insert_spy.assert_called_once()
9494

9595

96+
class TestPrefixCacheRestoreOffset:
97+
class _KVCacheWithoutOffsetSideEffect:
98+
def __init__(self) -> None:
99+
self._state: list[mx.array | None] = [None, None]
100+
self.offset = 0
101+
102+
@property
103+
def state(self) -> list[mx.array | None]:
104+
return self._state
105+
106+
@state.setter
107+
def state(self, value: list[mx.array]) -> None:
108+
# Intentionally does not mutate offset.
109+
self._state = value
110+
111+
def test_restore_cache_sets_offset_explicitly(self, monkeypatch) -> None:
112+
def fake_make_prompt_cache(_model):
113+
return [self._KVCacheWithoutOffsetSideEffect()]
114+
115+
monkeypatch.setattr(mr, "KVCache", self._KVCacheWithoutOffsetSideEffect)
116+
monkeypatch.setattr(mr, "make_prompt_cache", fake_make_prompt_cache)
117+
118+
k = mx.zeros((1, 2, 7, 8), dtype=mx.float32)
119+
v = mx.zeros((1, 2, 7, 8), dtype=mx.float32)
120+
cached = mr.CachedPrefix(token_ids=[1, 2, 3], cache_state=[(k, v)])
121+
122+
manager = mr.PrefixCacheManager(max_bytes=1024 * 1024)
123+
restored = manager.restore_cache(cached, model=MagicMock(), is_vlm=False)
124+
125+
restored_layer = restored[0]
126+
assert restored_layer.offset == 7
127+
assert bool(mx.allclose(restored_layer.state[0], k))
128+
assert bool(mx.allclose(restored_layer.state[1], v))
129+
130+
96131
class TestHybridCacheMergeExtract:
97132
"""Regression tests for hybrid (KV + ArraysCache) batching.
98133

vllm_metal/v1/model_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ def restore_cache(
278278
if isinstance(layer_cache, KVCache):
279279
k, v = cached.cache_state[i]
280280
layer_cache.state = [mx.array(k), mx.array(v)]
281+
# Keep RoPE position correct even if KVCache.state setter
282+
# behavior changes in future mlx-lm versions.
283+
layer_cache.offset = int(k.shape[2])
281284
return cache
282285

283286
@property

0 commit comments

Comments
 (0)