Skip to content

Commit 484fe7b

Browse files
committed
Fix cache handling, decode correctness, and no-grad execution path
1 parent ba9edd8 commit 484fe7b

1 file changed

Lines changed: 96 additions & 42 deletions

File tree

src/kvboost/cuda_graph_decode.py

Lines changed: 96 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
cache. This module does the same for kvboost:
77
88
1. After kvboost's (reuse-based) prefill produces the prompt KV, copy it into
9-
a HuggingFace ``StaticCache`` (pre-allocated, fixed-address buffers) via the
10-
cache's own ``update`` API — so reuse/TTFT is preserved, decode just runs
11-
on a graph-capturable, static-shape cache.
9+
a HuggingFace ``StaticCache`` (pre-allocated, fixed-address buffers) via
10+
direct tensor writes — so reuse/TTFT is preserved, decode just runs on a
11+
graph-capturable, static-shape cache.
1212
2. Capture the single-token decode forward into CUDA graphs via
1313
``torch.compile(mode="reduce-overhead")``, replayed per token; sampling is
1414
eager and outside the graph.
@@ -50,16 +50,21 @@
5050
log = logging.getLogger("kvboost.cuda_graph_decode")
5151

5252

53-
def _iter_kv(past_kv):
54-
"""Yield (key, value) per layer for DynamicCache (5.x ``layers`` or older
55-
``key_cache``) or a tuple-of-tuples legacy cache."""
53+
def _iter_kv(past_kv) -> List[tuple]:
54+
"""Return ``[(key_tensor, value_tensor), ...]`` for any supported cache
55+
format. Never returns the StaticCache — only called on the *input* KV
56+
(DynamicCache or tuple-of-tuples) that we copy *into* the StaticCache."""
5657
if past_kv is None:
5758
return []
58-
if hasattr(past_kv, "layers"): # transformers 5.x DynamicCache
59-
return [(l.keys, l.values) for l in past_kv.layers]
60-
if hasattr(past_kv, "key_cache"): # older DynamicCache
59+
# transformers 5.x: both DynamicCache and StaticCache have `.layers`.
60+
# DynamicLayer / StaticLayer both expose `.keys` / `.values`.
61+
if hasattr(past_kv, "layers"):
62+
return [(layer.keys, layer.values) for layer in past_kv.layers]
63+
# Older DynamicCache (4.x): flat key_cache / value_cache lists.
64+
if hasattr(past_kv, "key_cache"):
6165
return list(zip(past_kv.key_cache, past_kv.value_cache))
62-
return [(k, v) for (k, v) in past_kv] # tuple-of-tuples
66+
# Legacy tuple-of-tuples: ((k0, v0), (k1, v1), ...)
67+
return list(past_kv)
6368

6469

6570
@dataclass
@@ -119,9 +124,11 @@ def applicable(self, batch_size: int = 1) -> bool:
119124
# ------------------------------------------------------------------
120125
def _dims(self):
121126
c = self._config
122-
n_heads = int(getattr(c, "num_attention_heads"))
127+
n_heads = int(getattr(c, "num_attention_heads", 0) or 1)
123128
n_kv = int(getattr(c, "num_key_value_heads", n_heads))
124-
head_dim = int(getattr(c, "head_dim", 0) or (c.hidden_size // n_heads))
129+
head_dim = int(getattr(c, "head_dim", 0) or (
130+
c.hidden_size // n_heads if hasattr(c, "hidden_size") else 64
131+
))
125132
return n_kv, head_dim
126133

127134
def _compiled(self):
@@ -170,9 +177,9 @@ def _build(self) -> _Captured:
170177
for b in (inp, pos, cpos):
171178
_dyn.mark_static_address(b)
172179
for lyr in getattr(cache, "layers", []):
173-
if hasattr(lyr, "keys"):
180+
if hasattr(lyr, "keys") and lyr.is_initialized:
174181
_dyn.mark_static_address(lyr.keys)
175-
if hasattr(lyr, "values"):
182+
if hasattr(lyr, "values") and lyr.is_initialized:
176183
_dyn.mark_static_address(lyr.values)
177184
except Exception:
178185
pass
@@ -184,15 +191,24 @@ def _get_cache(self) -> _Captured:
184191
return self._cache
185192

186193
def _populate(self, cache, past_kv, L: int) -> None:
187-
"""Copy the prefill KV (length L) into the static cache via its update
188-
API, which also sets the length counter to L."""
194+
"""Copy the prefill KV (first L positions) into the static cache.
195+
196+
Resets the cache first (zeroes tensors + cumulative_length counters),
197+
then writes each layer's KV directly into the static buffers.
198+
StaticLayer.update auto-increments cumulative_length by the number of
199+
positions written — so after this call cumulative_length == L for every
200+
layer, which is what the decode loop's position tracking assumes.
201+
"""
189202
cache.reset()
190-
cpos = torch.arange(L, device=self.device)
191203
for i, (k, v) in enumerate(_iter_kv(past_kv)):
192204
k = k[:, :, :L, :].to(self.device, self.dtype)
193205
v = v[:, :, :L, :].to(self.device, self.dtype)
194-
cache.update(k, v, i, cache_kwargs={"cache_position": cpos})
206+
# StaticLayer.update ignores any kwargs; it tracks position via its
207+
# internal cumulative_length tensor. After reset() it starts at 0
208+
# and advances by the number of positions we write (= L here).
209+
cache.update(k, v, i)
195210

211+
@torch.no_grad()
196212
def _forward(self, cap: _Captured):
197213
fn = self._compiled() if self._use_compiled else None
198214
target = fn if fn is not None else self.model
@@ -204,34 +220,60 @@ def _forward(self, cap: _Captured):
204220
use_cache=True,
205221
)
206222

223+
@torch.no_grad()
207224
def _step_logits(self, cap: _Captured, tok: int, cur: int) -> torch.Tensor:
225+
"""Run one decode step: feed token ``tok`` at position ``cur``,
226+
return logits ``(1, vocab)`` for the next position.
227+
228+
Sets the three static input buffers in-place (avoids a new tensor
229+
allocation per step — required for CUDA-graph stable addresses).
230+
"""
208231
cap.input_ids[0, 0] = tok
209232
cap.pos_ids[0, 0] = cur
210233
cap.cache_pos[0] = cur
211234
out = self._forward(cap)
212235
return out.logits[:, -1, :]
213236

237+
@torch.no_grad()
214238
def _self_check(self, cap: _Captured, past_kv, L: int, seed: int,
215-
as_cache, k: int = 4) -> bool:
239+
k: int = 4) -> bool:
216240
"""Compare the first ``k`` GREEDY tokens from the compiled-graph path
217241
against an eager reference (original model, fresh DynamicCache). Catches
218242
capture/compile bugs (incl. a frozen mask) before any token is emitted.
219-
Mutates cap.cache (caller re-populates before the real loop)."""
243+
244+
IMPORTANT: builds a *copy* of past_kv for the reference decode so the
245+
original is not mutated in-place (DynamicCache is extended by the model
246+
on every step; sharing it would corrupt the caller's view of the prefill
247+
KV and shift the layer sequence lengths seen by subsequent _populate
248+
calls).
249+
"""
250+
from transformers import DynamicCache
251+
252+
# ── Reference: eager model + fresh DynamicCache copy ─────────────
220253
ref: List[int] = []
221-
with torch.no_grad():
222-
dyn = as_cache(past_kv)
223-
tok, cur = seed, L
224-
for _ in range(k):
225-
o = self.model(
226-
input_ids=torch.tensor([[tok]], device=self.device),
227-
position_ids=torch.tensor([[cur]], device=self.device),
228-
past_key_values=dyn, use_cache=True,
229-
)
230-
tok = int(o.logits[:, -1, :].argmax(-1).item())
231-
ref.append(tok)
232-
cur += 1
233-
if tok == self.eos:
234-
break
254+
ref_kv = DynamicCache()
255+
for i, (lk, lv) in enumerate(_iter_kv(past_kv)):
256+
ref_kv.update(
257+
lk[:, :, :L, :].clone().to(self.device),
258+
lv[:, :, :L, :].clone().to(self.device),
259+
i,
260+
)
261+
262+
tok, cur = seed, L
263+
for _ in range(k):
264+
o = self.model(
265+
input_ids=torch.tensor([[tok]], device=self.device),
266+
position_ids=torch.tensor([[cur]], device=self.device),
267+
past_key_values=ref_kv,
268+
use_cache=True,
269+
)
270+
tok = int(o.logits[:, -1, :].argmax(-1).item())
271+
ref.append(tok)
272+
cur += 1
273+
if tok == self.eos:
274+
break
275+
276+
# ── Compiled / eager-static path ─────────────────────────────────
235277
self._populate(cap.cache, past_kv, L)
236278
got: List[int] = []
237279
tok, cur = seed, L
@@ -242,6 +284,7 @@ def _self_check(self, cap: _Captured, past_kv, L: int, seed: int,
242284
cur += 1
243285
if tok == self.eos:
244286
break
287+
245288
ok = got == ref
246289
if ok:
247290
log.info("CUDA-graph decode self-check passed (%d greedy tokens "
@@ -251,6 +294,7 @@ def _self_check(self, cap: _Captured, past_kv, L: int, seed: int,
251294
"%s) — DISABLING, eager fallback.", got, ref)
252295
return ok
253296

297+
@torch.no_grad()
254298
def _measure_speedup(self, cc: _Captured, past_kv, L: int, seed: int,
255299
steps: int = 12) -> None:
256300
"""Time the compiled step vs an eager step on the SAME static cache and
@@ -267,25 +311,35 @@ def _measure_speedup(self, cc: _Captured, past_kv, L: int, seed: int,
267311
if compiled is None:
268312
return
269313

314+
kw = dict(input_ids=cc.input_ids, position_ids=cc.pos_ids,
315+
cache_position=cc.cache_pos, past_key_values=cc.cache,
316+
use_cache=True)
317+
270318
def _run(fn) -> float:
319+
# Reset cache to a clean state for every timing run so that
320+
# cumulative_length starts at L (not L + previous-step-count).
321+
# Without this, each fn() writes at an ever-increasing position
322+
# and the warmup / timed runs operate on different cache states.
271323
self._populate(cc.cache, past_kv, L)
272324
cc.input_ids[0, 0] = seed
273325
cc.pos_ids[0, 0] = L
274326
cc.cache_pos[0] = L
275327
for _ in range(3): # warm
276-
fn()
328+
fn(**kw)
329+
# Reset between warmup steps so each writes at position L
330+
self._populate(cc.cache, past_kv, L)
331+
cc.cache_pos[0] = L
277332
torch.cuda.synchronize()
278333
t0 = time.perf_counter()
279334
for _ in range(steps):
280-
fn()
335+
fn(**kw)
336+
self._populate(cc.cache, past_kv, L)
337+
cc.cache_pos[0] = L
281338
torch.cuda.synchronize()
282339
return (time.perf_counter() - t0) / steps * 1000.0
283340

284-
kw = dict(input_ids=cc.input_ids, position_ids=cc.pos_ids,
285-
cache_position=cc.cache_pos, past_key_values=cc.cache,
286-
use_cache=True)
287-
ms_compiled = _run(lambda: compiled(**kw))
288-
ms_eager = _run(lambda: self.model(**kw))
341+
ms_compiled = _run(compiled)
342+
ms_eager = _run(self.model)
289343
ratio = ms_eager / max(ms_compiled, 1e-6)
290344
if ratio >= 1.15:
291345
log.info("CUDA-graph decode speedup: %.1f→%.1f ms/step "
@@ -325,7 +379,7 @@ def decode(self, *, past_kv, start_pos: int, seed_token: int,
325379
# token is emitted, so a failure cleanly falls back to eager.
326380
if self._use_compiled and not self._self_checked:
327381
self._self_checked = True
328-
if not self._self_check(cc, past_kv, L, seed_token, as_cache):
382+
if not self._self_check(cc, past_kv, L, seed_token):
329383
self._disabled = True
330384
return None
331385
# Self-check only proves CORRECTNESS. Also measure SPEED so a

0 commit comments

Comments
 (0)