66cache. This module does the same for kvboost:
77
88 1. After kvboost's (reuse-based) prefill produces the prompt KV, copy it into
9- a HuggingFace ``StaticCache`` (pre-allocated, fixed-address buffers) via the
10- cache's own ``update`` API — so reuse/TTFT is preserved, decode just runs
11- on a graph-capturable, static-shape cache.
9+ a HuggingFace ``StaticCache`` (pre-allocated, fixed-address buffers) via
10+ the cache's own ``update`` API (the cache tracks its own write position) —
11+ so reuse/TTFT is preserved, decode just runs on a graph-capturable,
11+ static-shape cache.
1212 2. Capture the single-token decode forward into CUDA graphs via
1313 ``torch.compile(mode="reduce-overhead")``, replayed per token; sampling is
1414 eager and outside the graph.
5050log = logging .getLogger ("kvboost.cuda_graph_decode" )
5151
5252
def _iter_kv(past_kv) -> List[tuple]:
    """Return the per-layer ``(key, value)`` pairs held by a prefill KV cache.

    Supported inputs, checked in this order:

    * ``None`` -- returns an empty list.
    * An object with a ``layers`` attribute (transformers 5.x cache objects)
      whose layers expose ``.keys`` and ``.values``.
    * An object with flat ``key_cache`` / ``value_cache`` lists (older
      ``DynamicCache``).
    * A legacy sequence of per-layer entries ``((k0, v0), (k1, v1), ...)``.

    The result is always a list of 2-tuples in layer order, so callers can
    unpack ``for i, (k, v) in enumerate(_iter_kv(...))`` whatever the input
    format was.
    """
    if past_kv is None:
        return []
    # Checked first: 5.x caches expose per-layer objects under ``layers``.
    if hasattr(past_kv, "layers"):
        return [(layer.keys, layer.values) for layer in past_kv.layers]
    # Older DynamicCache: two parallel lists indexed by layer.
    if hasattr(past_kv, "key_cache"):
        return list(zip(past_kv.key_cache, past_kv.value_cache))
    # Legacy format: normalise each entry to a (key, value) 2-tuple so list
    # entries, or entries carrying extra tensors after the first two, still
    # match the annotated return type and the callers' 2-way unpacking.
    return [tuple(entry)[:2] for entry in past_kv]
6368
6469
6570@dataclass
@@ -119,9 +124,11 @@ def applicable(self, batch_size: int = 1) -> bool:
119124 # ------------------------------------------------------------------
def _dims(self):
    """Return ``(n_kv_heads, head_dim)`` derived from ``self._config``.

    Lookup rules:

    * ``num_attention_heads`` -- falls back to ``1`` when missing or falsy.
    * ``num_key_value_heads`` -- falls back to the attention-head count when
      missing *or* explicitly ``None`` (a plain ``getattr`` default would
      hand ``None`` to ``int()`` and raise ``TypeError``).
    * ``head_dim`` -- used when set; otherwise ``hidden_size // n_heads``;
      otherwise ``64`` when ``hidden_size`` is unavailable.

    NOTE(review): the ``1`` / ``64`` fallbacks are guesses carried over from
    the existing code; confirm they are wanted for configs that omit these
    fields rather than a hard failure.
    """
    cfg = self._config
    n_heads = int(getattr(cfg, "num_attention_heads", 0) or 1)
    n_kv = int(getattr(cfg, "num_key_value_heads", None) or n_heads)

    head_dim = getattr(cfg, "head_dim", 0)
    if not head_dim:
        hidden = getattr(cfg, "hidden_size", None)
        head_dim = hidden // n_heads if hidden else 64
    return n_kv, int(head_dim)
126133
127134 def _compiled (self ):
@@ -170,9 +177,9 @@ def _build(self) -> _Captured:
170177 for b in (inp , pos , cpos ):
171178 _dyn .mark_static_address (b )
172179 for lyr in getattr (cache , "layers" , []):
173- if hasattr (lyr , "keys" ):
180+ if hasattr (lyr , "keys" ) and lyr . is_initialized :
174181 _dyn .mark_static_address (lyr .keys )
175- if hasattr (lyr , "values" ):
182+ if hasattr (lyr , "values" ) and lyr . is_initialized :
176183 _dyn .mark_static_address (lyr .values )
177184 except Exception :
178185 pass
@@ -184,15 +191,24 @@ def _get_cache(self) -> _Captured:
184191 return self ._cache
185192
def _populate(self, cache, past_kv, L: int) -> None:
    """Copy the first ``L`` positions of the prefill KV into the static cache.

    Steps:

    1. ``cache.reset()`` clears whatever the cache held before.
    2. For every layer returned by ``_iter_kv(past_kv)``, the key and value
       tensors are sliced to ``[:, :, :L, :]``, moved to ``self.device`` /
       ``self.dtype``, and handed to ``cache.update(k, v, layer_idx)``.

    No ``cache_position`` is passed to ``update``.

    NOTE(review): this relies on the static cache tracking its own write
    position (starting at 0 after ``reset()`` and advancing by the number of
    positions written), so that every layer ends at length ``L``. Confirm
    this against the installed transformers version -- the update contract
    has differed between releases.

    A prefill KV holding fewer than ``L`` positions cannot fill the cache up
    to the decode start position; that case is logged once as a warning and
    the copy proceeds with what is available.
    """
    cache.reset()
    warned = False
    for layer_idx, (k, v) in enumerate(_iter_kv(past_kv)):
        # Axis 2 is the one sliced to L below, so it is the length to check.
        if not warned and min(k.shape[2], v.shape[2]) < L:
            log.warning(
                "prefill KV for layer %d holds fewer than %d positions; "
                "static cache will be shorter than the decode start position",
                layer_idx, L,
            )
            warned = True
        k = k[:, :, :L, :].to(self.device, self.dtype)
        v = v[:, :, :L, :].to(self.device, self.dtype)
        cache.update(k, v, layer_idx)
195210
211+ @torch .no_grad ()
196212 def _forward (self , cap : _Captured ):
197213 fn = self ._compiled () if self ._use_compiled else None
198214 target = fn if fn is not None else self .model
@@ -204,34 +220,60 @@ def _forward(self, cap: _Captured):
204220 use_cache = True ,
205221 )
206222
@torch.no_grad()
def _step_logits(self, cap: _Captured, tok: int, cur: int) -> torch.Tensor:
    """Run a single decode step and return the logits of the last position.

    The token id ``tok`` and its position ``cur`` are written in place into
    the captured input buffers (``input_ids``, ``pos_ids``, ``cache_pos``)
    rather than building new tensors, so the buffers handed to
    ``self._forward`` keep the same storage from step to step.

    Returns ``out.logits[:, -1, :]`` -- presumably shaped ``(1, vocab)``;
    confirm against the model's output.
    """
    # In-place element writes only: the buffer objects themselves are reused.
    cap.input_ids[0, 0] = tok
    cap.cache_pos[0] = cur
    cap.pos_ids[0, 0] = cur

    logits = self._forward(cap).logits
    return logits[:, -1, :]
213236
237+ @torch .no_grad ()
214238 def _self_check (self , cap : _Captured , past_kv , L : int , seed : int ,
215- as_cache , k : int = 4 ) -> bool :
239+ k : int = 4 ) -> bool :
216240 """Compare the first ``k`` GREEDY tokens from the compiled-graph path
217241 against an eager reference (original model, fresh DynamicCache). Catches
218242 capture/compile bugs (incl. a frozen mask) before any token is emitted.
219- Mutates cap.cache (caller re-populates before the real loop)."""
243+
244+ IMPORTANT: builds a *copy* of past_kv for the reference decode so the
245+ original is not mutated in-place (DynamicCache is extended by the model
246+ on every step; sharing it would corrupt the caller's view of the prefill
247+ KV and shift the layer sequence lengths seen by subsequent _populate
248+ calls).
249+ """
250+ from transformers import DynamicCache
251+
252+ # ── Reference: eager model + fresh DynamicCache copy ─────────────
220253 ref : List [int ] = []
221- with torch .no_grad ():
222- dyn = as_cache (past_kv )
223- tok , cur = seed , L
224- for _ in range (k ):
225- o = self .model (
226- input_ids = torch .tensor ([[tok ]], device = self .device ),
227- position_ids = torch .tensor ([[cur ]], device = self .device ),
228- past_key_values = dyn , use_cache = True ,
229- )
230- tok = int (o .logits [:, - 1 , :].argmax (- 1 ).item ())
231- ref .append (tok )
232- cur += 1
233- if tok == self .eos :
234- break
254+ ref_kv = DynamicCache ()
255+ for i , (lk , lv ) in enumerate (_iter_kv (past_kv )):
256+ ref_kv .update (
257+ lk [:, :, :L , :].clone ().to (self .device ),
258+ lv [:, :, :L , :].clone ().to (self .device ),
259+ i ,
260+ )
261+
262+ tok , cur = seed , L
263+ for _ in range (k ):
264+ o = self .model (
265+ input_ids = torch .tensor ([[tok ]], device = self .device ),
266+ position_ids = torch .tensor ([[cur ]], device = self .device ),
267+ past_key_values = ref_kv ,
268+ use_cache = True ,
269+ )
270+ tok = int (o .logits [:, - 1 , :].argmax (- 1 ).item ())
271+ ref .append (tok )
272+ cur += 1
273+ if tok == self .eos :
274+ break
275+
276+ # ── Compiled / eager-static path ─────────────────────────────────
235277 self ._populate (cap .cache , past_kv , L )
236278 got : List [int ] = []
237279 tok , cur = seed , L
@@ -242,6 +284,7 @@ def _self_check(self, cap: _Captured, past_kv, L: int, seed: int,
242284 cur += 1
243285 if tok == self .eos :
244286 break
287+
245288 ok = got == ref
246289 if ok :
247290 log .info ("CUDA-graph decode self-check passed (%d greedy tokens "
@@ -251,6 +294,7 @@ def _self_check(self, cap: _Captured, past_kv, L: int, seed: int,
251294 "%s) — DISABLING, eager fallback." , got , ref )
252295 return ok
253296
297+ @torch .no_grad ()
254298 def _measure_speedup (self , cc : _Captured , past_kv , L : int , seed : int ,
255299 steps : int = 12 ) -> None :
256300 """Time the compiled step vs an eager step on the SAME static cache and
@@ -267,25 +311,35 @@ def _measure_speedup(self, cc: _Captured, past_kv, L: int, seed: int,
267311 if compiled is None :
268312 return
269313
314+ kw = dict (input_ids = cc .input_ids , position_ids = cc .pos_ids ,
315+ cache_position = cc .cache_pos , past_key_values = cc .cache ,
316+ use_cache = True )
317+
270318 def _run (fn ) -> float :
319+ # Reset cache to a clean state for every timing run so that
320+ # cumulative_length starts at L (not L + previous-step-count).
321+ # Without this, each fn() writes at an ever-increasing position
322+ # and the warmup / timed runs operate on different cache states.
271323 self ._populate (cc .cache , past_kv , L )
272324 cc .input_ids [0 , 0 ] = seed
273325 cc .pos_ids [0 , 0 ] = L
274326 cc .cache_pos [0 ] = L
275327 for _ in range (3 ): # warm
276- fn ()
328+ fn (** kw )
329+ # Reset between warmup steps so each writes at position L
330+ self ._populate (cc .cache , past_kv , L )
331+ cc .cache_pos [0 ] = L
277332 torch .cuda .synchronize ()
278333 t0 = time .perf_counter ()
279334 for _ in range (steps ):
280- fn ()
335+ fn (** kw )
336+ self ._populate (cc .cache , past_kv , L )
337+ cc .cache_pos [0 ] = L
281338 torch .cuda .synchronize ()
282339 return (time .perf_counter () - t0 ) / steps * 1000.0
283340
284- kw = dict (input_ids = cc .input_ids , position_ids = cc .pos_ids ,
285- cache_position = cc .cache_pos , past_key_values = cc .cache ,
286- use_cache = True )
287- ms_compiled = _run (lambda : compiled (** kw ))
288- ms_eager = _run (lambda : self .model (** kw ))
341+ ms_compiled = _run (compiled )
342+ ms_eager = _run (self .model )
289343 ratio = ms_eager / max (ms_compiled , 1e-6 )
290344 if ratio >= 1.15 :
291345 log .info ("CUDA-graph decode speedup: %.1f→%.1f ms/step "
@@ -325,7 +379,7 @@ def decode(self, *, past_kv, start_pos: int, seed_token: int,
325379 # token is emitted, so a failure cleanly falls back to eager.
326380 if self ._use_compiled and not self ._self_checked :
327381 self ._self_checked = True
328- if not self ._self_check (cc , past_kv , L , seed_token , as_cache ):
382+ if not self ._self_check (cc , past_kv , L , seed_token ):
329383 self ._disabled = True
330384 return None
331385 # Self-check only proves CORRECTNESS. Also measure SPEED so a
0 commit comments