99 * KV cache "low" → lower streaming residency (``keep_first_k`` /
1010 ``keep_last_k``). Layers drop out of VRAM, costing streaming overhead
1111 but unblocking the request.
12+ * Cache nearly empty AND streaming disabled → lower ``prefill_chunk_size``
13+ on the engine. Cuts peak prefill activation memory (the real culprit
14+ for long prompts, where neither knob above can help).
1215
13- A repeated OOM with the *same* knob just adjusted flips to the other knob,
16+ A repeated OOM with the *same* knob just adjusted flips to the next knob,
1417so a single persistent bottleneck still gets fully addressed. Each knob has
15- a floor (``MIN_CACHE_BYTES`` / ``MIN_KEEP``); when both are exhausted the
16- original exception re-raises.
18+ a floor (``MIN_CACHE_BYTES`` / ``MIN_KEEP`` / ``MIN_PREFILL_CHUNK`` ); when
19+ all are exhausted the original exception re-raises.
1720
1821This module is used by:
1922 * the kvboost inference server (``kvboost.server.engine_worker``)
@@ -45,20 +48,24 @@ class Diagnosis:
4548 once per request to avoid masking real OOMs.
4649 * ``"cache_dominant"`` — cache holds ≥ 1.5× tried-alloc; eviction will
4750 free more than the failed allocation needs.
51+ * ``"prefill_bound"`` — cache holds < 0.5× tried-alloc AND streaming
52+ can't help (disabled or exhausted); the OOM is prefill activation
53+ memory, not KV cache. Lower ``prefill_chunk_size`` on the engine.
4854 * ``"residency_bound"`` — cache holds < 0.5× tried-alloc; cache can't
4955 plausibly close the gap, must lower streaming residency.
50- * ``"mixed"`` — neither knob is the obvious culprit; pick the one with
56+ * ``"mixed"`` — no knob is the obvious culprit; pick the one with
5157 more remaining headroom (% from floor), cache wins ties (cheaper).
5258 """
5359
5460 tier : str
55- action : str # "empty_cache_only" | "lower_cache" | "lower_streaming"
61+ action : str # "empty_cache_only" | "lower_cache" | "lower_streaming" | "lower_prefill_chunk"
5662 reason : str
5763 parsed_oom : Dict [str , Optional [float ]] = field (default_factory = dict )
5864 cache_used_mb : float = 0.0
5965 cache_budget_mb : float = 0.0
6066 cache_headroom_frac : float = 0.0
6167 stream_headroom_frac : float = 0.0
68+ prefill_headroom_frac : float = 0.0
6269
6370
6471# Regexes for torch's CUDA OOM message format. Defensive: any field that
@@ -126,8 +133,13 @@ class OOMRecovery:
126133
127134 CACHE_SHRINK = 0.7
128135 STREAM_SHRINK = 0.5
136+ PREFILL_SHRINK = 0.5
129137 MIN_CACHE_BYTES = int (2.5e8 ) # 250 MB floor
130138 MIN_KEEP = 1 # absolute floor — at least 1 resident layer each side
139+ MIN_PREFILL_CHUNK = 32 # absolute floor for prefill chunk size
140+ # Starting value when the engine was launched with prefill_chunk_size=0
141+ # (single-shot prefill) and we need to enable chunking on first OOM.
142+ INITIAL_PREFILL_CHUNK_ON_OOM = 2048
131143 SAFETY_CAP = 16 # absolute attempt cap to avoid pathological loops
132144
133145 def __init__ (
@@ -138,6 +150,7 @@ def __init__(
138150 initial_keep_first_k : Optional [int ],
139151 initial_keep_last_k : Optional [int ],
140152 streaming_enabled : bool ,
153+ initial_prefill_chunk_size : int = 0 ,
141154 max_retries : Optional [int ] = None ,
142155 ):
143156 """Initialise recovery state.
@@ -152,6 +165,7 @@ def __init__(
152165 self .keep_first_k = initial_keep_first_k
153166 self .keep_last_k = initial_keep_last_k
154167 self .streaming_enabled = streaming_enabled
168+ self .prefill_chunk_size = int (initial_prefill_chunk_size )
155169 self .max_retries = max_retries
156170 self .events : List [Dict [str , Any ]] = []
157171 # Lifetime counter (across all calls) for fragmentation diagnoses;
@@ -195,6 +209,16 @@ def _stream_headroom_frac(self) -> float:
195209 return 0.0
196210 return (self .keep_first_k - self .MIN_KEEP ) / self .keep_first_k
197211
212+ def _prefill_headroom_frac (self ) -> float :
213+ """How much room is left to shrink prefill chunk size. A current value
214+ of 0 (single-shot prefill) counts as full headroom — we can switch to
215+ ``INITIAL_PREFILL_CHUNK_ON_OOM`` on first OOM."""
216+ if self .prefill_chunk_size == 0 :
217+ return 1.0
218+ if self .prefill_chunk_size <= self .MIN_PREFILL_CHUNK :
219+ return 0.0
220+ return (self .prefill_chunk_size - self .MIN_PREFILL_CHUNK ) / self .prefill_chunk_size
221+
198222 # ── 4-tier diagnosis ──
199223 def _diagnose (self , err : BaseException , * , allow_fragmentation_tier : bool ) -> Diagnosis :
200224 """Classify an OOM into one of four tiers and pick the cheapest action.
@@ -208,6 +232,7 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
208232 cache_budget_mb = self .max_cache_bytes / (1024.0 ** 2 )
209233 cache_h = self ._cache_headroom_frac ()
210234 stream_h = self ._stream_headroom_frac ()
235+ prefill_h = self ._prefill_headroom_frac ()
211236
212237 tried = parsed ["tried_alloc_mb" ]
213238 frag = parsed ["reserved_unalloc_mb" ]
@@ -230,25 +255,39 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
230255 cache_budget_mb = cache_budget_mb ,
231256 cache_headroom_frac = cache_h ,
232257 stream_headroom_frac = stream_h ,
258+ prefill_headroom_frac = prefill_h ,
233259 )
234260
235261 # ── No tried_alloc → fall back to a simple budget-fraction heuristic ──
236- # Old behaviour: if cache > HIGH_FRAC × budget, lower cache; else streaming.
262+ # Cache-empty + parseless OOM = almost always prefill activation.
263+ # Pick prefill_chunk first if it has headroom, then fall back to the
264+ # legacy cache-fraction heuristic.
237265 if tried is None :
238- cache_dominant = cache_used_mb > 0.5 * cache_budget_mb
239- action = "lower_cache" if cache_dominant or not self .streaming_enabled else "lower_streaming"
266+ cache_nearly_empty = cache_used_mb < 0.1 * max (cache_budget_mb , 1.0 )
267+ if cache_nearly_empty and prefill_h > 0 :
268+ action = "lower_prefill_chunk"
269+ reason = (
270+ f"OOM message unparseable; cache is nearly empty "
271+ f"({ cache_used_mb :.0f} /{ cache_budget_mb :.0f} MiB) "
272+ f"→ lowering prefill chunk"
273+ )
274+ else :
275+ cache_dominant = cache_used_mb > 0.5 * cache_budget_mb
276+ action = "lower_cache" if cache_dominant or not self .streaming_enabled else "lower_streaming"
277+ reason = (
278+ f"OOM message unparseable; legacy heuristic: cache_used "
279+ f"{ cache_used_mb :.0f} /{ cache_budget_mb :.0f} MiB → { action } "
280+ )
240281 return Diagnosis (
241282 tier = "mixed" ,
242283 action = action ,
243- reason = (
244- f"OOM message unparseable; legacy heuristic: cache_used "
245- f"{ cache_used_mb :.0f} /{ cache_budget_mb :.0f} MiB → { action } "
246- ),
284+ reason = reason ,
247285 parsed_oom = parsed ,
248286 cache_used_mb = cache_used_mb ,
249287 cache_budget_mb = cache_budget_mb ,
250288 cache_headroom_frac = cache_h ,
251289 stream_headroom_frac = stream_h ,
290+ prefill_headroom_frac = prefill_h ,
252291 )
253292
254293 # ── Tier 2: Cache dominant ──
@@ -266,9 +305,33 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
266305 cache_budget_mb = cache_budget_mb ,
267306 cache_headroom_frac = cache_h ,
268307 stream_headroom_frac = stream_h ,
308+ prefill_headroom_frac = prefill_h ,
309+ )
310+
311+ # ── Tier 3a: Prefill bound (cache can't help; streaming can't either) ──
312+ # cache_used << tried and either streaming is disabled or already at
313+ # its floor → the OOM is almost certainly prefill activation memory.
314+ # Lower prefill_chunk_size on the engine to cap peak per-step memory.
315+ if (cache_used_mb < self .TIER_RESIDENCY_BOUND * tried
316+ and prefill_h > 0
317+ and (not self .streaming_enabled or stream_h <= 0.0 )):
318+ return Diagnosis (
319+ tier = "prefill_bound" ,
320+ action = "lower_prefill_chunk" ,
321+ reason = (
322+ f"cache_used { cache_used_mb :.0f} MiB << tried-alloc "
323+ f"{ tried :.0f} MiB and streaming can't help — OOM is "
324+ f"prefill activation, lowering chunk size"
325+ ),
326+ parsed_oom = parsed ,
327+ cache_used_mb = cache_used_mb ,
328+ cache_budget_mb = cache_budget_mb ,
329+ cache_headroom_frac = cache_h ,
330+ stream_headroom_frac = stream_h ,
331+ prefill_headroom_frac = prefill_h ,
269332 )
270333
271- # ── Tier 3 : Residency bound ──
334+ # ── Tier 3b : Residency bound ──
272335 if cache_used_mb < self .TIER_RESIDENCY_BOUND * tried and self .streaming_enabled :
273336 return Diagnosis (
274337 tier = "residency_bound" ,
@@ -283,25 +346,29 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
283346 cache_budget_mb = cache_budget_mb ,
284347 cache_headroom_frac = cache_h ,
285348 stream_headroom_frac = stream_h ,
349+ prefill_headroom_frac = prefill_h ,
286350 )
287351
288352 # ── Tier 4: Mixed ──
289- # Cache wins ties because eviction is cheaper (and stream headroom is
290- # -1.0 when streaming is off, so cache also wins by default there).
291- action = "lower_streaming" if stream_h > cache_h else "lower_cache"
353+ # Highest remaining headroom wins; cache breaks ties (cheapest).
354+ candidates = [("lower_cache" , cache_h ), ("lower_streaming" , stream_h ),
355+ ("lower_prefill_chunk" , prefill_h )]
356+ action = max (candidates , key = lambda kv : kv [1 ])[0 ]
292357 return Diagnosis (
293358 tier = "mixed" ,
294359 action = action ,
295360 reason = (
296361 f"ambiguous (cache_used { cache_used_mb :.0f} MiB ≈ tried "
297362 f"{ tried :.0f} MiB): cache_headroom={ cache_h :.0%} , "
298- f"stream_headroom={ stream_h :.0%} → { action } "
363+ f"stream_headroom={ stream_h :.0%} , "
364+ f"prefill_headroom={ prefill_h :.0%} → { action } "
299365 ),
300366 parsed_oom = parsed ,
301367 cache_used_mb = cache_used_mb ,
302368 cache_budget_mb = cache_budget_mb ,
303369 cache_headroom_frac = cache_h ,
304370 stream_headroom_frac = stream_h ,
371+ prefill_headroom_frac = prefill_h ,
305372 )
306373
307374 # ── Knob adjusters ──
@@ -390,6 +457,31 @@ def _lower_streaming(self) -> Optional[Dict[str, Any]]:
390457 "keep_last_k" : new_last ,
391458 }
392459
460+ def _lower_prefill_chunk (self ) -> Optional [Dict [str , Any ]]:
461+ """Reduce engine.prefill_chunk_size to cap peak prefill activation
462+ memory. From 0 (single-shot) we jump to ``INITIAL_PREFILL_CHUNK_ON_OOM``;
463+ otherwise we halve down to ``MIN_PREFILL_CHUNK``. Returns None when the
464+ engine doesn't expose the knob or we've hit the floor."""
465+ if not hasattr (self .engine , "prefill_chunk_size" ):
466+ return None
467+ old = self .prefill_chunk_size
468+ if old == 0 :
469+ new = self .INITIAL_PREFILL_CHUNK_ON_OOM
470+ else :
471+ new = max (self .MIN_PREFILL_CHUNK , int (old * self .PREFILL_SHRINK ))
472+ if new >= old :
473+ return None
474+ try :
475+ self .engine .prefill_chunk_size = new
476+ except Exception :
477+ return None
478+ self .prefill_chunk_size = new
479+ return {
480+ "action" : "lower_prefill_chunk" ,
481+ "old_prefill_chunk_size" : old ,
482+ "new_prefill_chunk_size" : new ,
483+ }
484+
393485 # ── Driver ──
394486 def attempt (
395487 self ,
@@ -527,32 +619,45 @@ def _apply_diagnosis(
527619 # ── Tiers 2-4: knob change ──
528620 # For cache-action tiers, _lower_cache calls cm.clear() itself, so we
529621 # don't pre-emptively reset_cache (that would lose chunks before the
530- # knob even applies). For residency tiers, we keep the cache intact —
531- # losing it doesn't help residency pressure and gives up valuable reuse.
622+ # knob even applies). For residency / prefill tiers, we keep the cache
623+ # intact — losing it doesn't help and gives up valuable reuse.
624+ knobs = {
625+ "lower_cache" : self ._lower_cache ,
626+ "lower_streaming" : self ._lower_streaming ,
627+ "lower_prefill_chunk" : self ._lower_prefill_chunk ,
628+ }
532629 primary_name = dx .action
533- secondary_name = "lower_streaming" if primary_name == "lower_cache" else "lower_cache"
534- primary = self ._lower_cache if primary_name == "lower_cache" else self ._lower_streaming
535- secondary = self ._lower_streaming if primary_name == "lower_cache" else self ._lower_cache
536-
537- change = primary ()
538- flipped = False
539- if change is None :
540- log .warning (
541- "OOM recovery: diagnosed action '%s' unavailable (floor reached or disabled); "
542- "falling back to '%s'." ,
543- primary_name , secondary_name ,
544- )
545- change = secondary ()
546- flipped = True
630+ # Fallback order: try the diagnosed knob first, then the others in a
631+ # fixed order (cache → prefill → streaming). Prefill before streaming
632+ # because most cache-empty OOMs are prefill activations, not weights.
633+ fallback_chain = [primary_name ] + [
634+ k for k in ("lower_cache" , "lower_prefill_chunk" , "lower_streaming" )
635+ if k != primary_name
636+ ]
637+
638+ change = None
639+ flipped_to = None
640+ for i , knob_name in enumerate (fallback_chain ):
641+ change = knobs [knob_name ]()
642+ if change is not None :
643+ if i > 0 :
644+ flipped_to = knob_name
645+ log .warning (
646+ "OOM recovery: diagnosed action '%s' unavailable; "
647+ "falling back to '%s'." ,
648+ primary_name , knob_name ,
649+ )
650+ break
547651
548652 if change is None :
549653 log .error (
550654 "OOM recovery EXHAUSTED at attempt %d (tier=%s): "
551- "cache=%.2f GB (floor=%.2f GB), keep=%s/%s (floor=%d). "
552- "Re-raising original CUDA OOM." ,
655+ "cache=%.2f GB (floor=%.2f GB), keep=%s/%s (floor=%d), "
656+ "prefill_chunk=%d (floor=%d). Re-raising original CUDA OOM." ,
553657 attempt_idx , dx .tier ,
554658 self .max_cache_bytes / 1e9 , self .MIN_CACHE_BYTES / 1e9 ,
555659 self .keep_first_k , self .keep_last_k , self .MIN_KEEP ,
660+ self .prefill_chunk_size , self .MIN_PREFILL_CHUNK ,
556661 )
557662 return None
558663
@@ -565,31 +670,39 @@ def _apply_diagnosis(
565670 change ["cache_budget_mb" ] = dx .cache_budget_mb
566671 change ["cache_headroom_frac" ] = dx .cache_headroom_frac
567672 change ["stream_headroom_frac" ] = dx .stream_headroom_frac
568- change ["flipped_to_secondary" ] = flipped
673+ change ["prefill_headroom_frac" ] = dx .prefill_headroom_frac
674+ change ["flipped_to_secondary" ] = flipped_to is not None
675+ if flipped_to is not None :
676+ change ["flipped_to" ] = flipped_to
569677 self .events .append (change )
570678
571- if change ["action" ] == "lower_cache" :
679+ action = change ["action" ]
680+ if action == "lower_cache" :
572681 summary = (
573682 f"lower_cache: max_cache_bytes "
574683 f"{ change ['old_max_cache_bytes' ] / 1e9 :.2f} GB → "
575684 f"{ change ['new_max_cache_bytes' ] / 1e9 :.2f} GB"
576685 )
577- change ["summary" ] = summary
578- log .warning (
579- "OOM recovery → %s (×%.2f, tier=%s, attempt=%d). Will retry." ,
580- summary , self .CACHE_SHRINK , dx .tier , attempt_idx ,
581- )
582- else :
686+ shrink = self .CACHE_SHRINK
687+ elif action == "lower_streaming" :
583688 summary = (
584689 f"lower_streaming: keep_first_k { change ['old_keep_first_k' ]} →"
585690 f"{ change ['keep_first_k' ]} , keep_last_k "
586691 f"{ change ['old_keep_last_k' ] or 0 } →{ change ['keep_last_k' ]} "
587692 )
588- change ["summary" ] = summary
589- log .warning (
590- "OOM recovery → %s (×%.2f, tier=%s, attempt=%d). Will retry." ,
591- summary , self .STREAM_SHRINK , dx .tier , attempt_idx ,
693+ shrink = self .STREAM_SHRINK
694+ else : # lower_prefill_chunk
695+ summary = (
696+ f"lower_prefill_chunk: prefill_chunk_size "
697+ f"{ change ['old_prefill_chunk_size' ]} → "
698+ f"{ change ['new_prefill_chunk_size' ]} "
592699 )
700+ shrink = self .PREFILL_SHRINK
701+ change ["summary" ] = summary
702+ log .warning (
703+ "OOM recovery → %s (×%.2f, tier=%s, attempt=%d). Will retry." ,
704+ summary , shrink , dx .tier , attempt_idx ,
705+ )
593706 return change
594707
595708 def snapshot (self ) -> Dict [str , Any ]:
@@ -600,5 +713,6 @@ def snapshot(self) -> Dict[str, Any]:
600713 "max_cache_bytes" : self .max_cache_bytes ,
601714 "keep_first_k" : self .keep_first_k ,
602715 "keep_last_k" : self .keep_last_k ,
716+ "prefill_chunk_size" : self .prefill_chunk_size ,
603717 "events" : list (self .events ),
604718 }
0 commit comments