     on the engine. Cuts peak prefill activation memory (the real culprit
     for long prompts, where neither knob above can help).
 
-A repeated OOM with the *same* knob just adjusted flips to the next knob,
-so a single persistent bottleneck still gets fully addressed. Each knob has
-a floor (``MIN_CACHE_BYTES`` / ``MIN_KEEP`` / ``MIN_PREFILL_CHUNK``); when
-all are exhausted the original exception re-raises.
+Each tier is authoritative: cache-dominant OOMs touch *only* the cache,
+prefill-bound OOMs touch *only* prefill_chunk_size, residency-bound OOMs
+touch *only* streaming residency. Falling back to a different knob would
+damage unrelated state (e.g. dropping KV reuse when the real problem was
+prompt length) without addressing the cause. The only tier with cross-knob
+fallback is ``mixed`` — by definition genuinely ambiguous — and even there
+candidates are filtered to those that can plausibly help. A repeated OOM
+gets a fresh diagnosis, so a shifting bottleneck still ends up at the
+right knob across attempts. When the diagnosed knob is at its floor, the
+original exception re-raises.
 
 This module is used by:
   * the kvboost inference server (``kvboost.server.engine_worker``)
@@ -376,29 +382,46 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
     # decrease is possible (clamped at floor, or initial value was already
     # below the floor) is the signal to the caller that this knob is exhausted.
     def _lower_cache(self) -> Optional[Dict[str, Any]]:
+        """Free KV cache memory. The freeing happens via ``cm.clear()``;
+        lowering the budget is just a prophylactic so it doesn't refill to
+        the same level on the next request.
+
+        Returns None when this knob can't help — either the budget is at
+        floor AND the cache is empty (nothing to clear, no room to shrink),
+        or ``cm.clear()`` is unavailable. When the budget is at floor but
+        the cache is still populated, we still evict (budget unchanged).
+        """
+        cm = self.engine.cache_manager
         old_bytes = self.max_cache_bytes
-        # Apply shrink, then clamp UP to the floor so we never go below it,
-        # then require strict decrease relative to old (which prevents the
-        # floor clamp from ever raising the value).
         new_bytes = max(self.MIN_CACHE_BYTES, int(old_bytes * self.CACHE_SHRINK))
-        if new_bytes >= old_bytes:
+        cache_used = self._cache_bytes_used()
+
+        budget_can_shrink = new_bytes < old_bytes
+        evict_can_help = cache_used > 0
+
+        if not budget_can_shrink and not evict_can_help:
             return None
-        cm = self.engine.cache_manager
-        for attr in ("max_cache_bytes", "_max_cache_bytes", "max_bytes"):
-            if hasattr(cm, attr):
-                try:
-                    setattr(cm, attr, new_bytes)
-                except Exception:
-                    pass
-        self.max_cache_bytes = new_bytes
-        try:
-            cm.clear()
-        except Exception:
-            pass
+
+        if budget_can_shrink:
+            for attr in ("max_cache_bytes", "_max_cache_bytes", "max_bytes"):
+                if hasattr(cm, attr):
+                    try:
+                        setattr(cm, attr, new_bytes)
+                    except Exception:
+                        pass
+            self.max_cache_bytes = new_bytes
+
+        if evict_can_help:
+            try:
+                cm.clear()
+            except Exception:
+                pass
+
         return {
             "action": "lower_cache",
             "old_max_cache_bytes": old_bytes,
-            "new_max_cache_bytes": new_bytes,
+            "new_max_cache_bytes": self.max_cache_bytes,
+            "evicted_bytes": cache_used if evict_can_help else 0,
         }
 
     def _lower_streaming(self) -> Optional[Dict[str, Any]]:
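
The evict-versus-shrink decision in the new `_lower_cache` can be checked in isolation. The sketch below is a standalone restatement, not the module's code: `plan_lower_cache` is a hypothetical helper, and the 256 MiB floor and 0.5 shrink factor are assumed values standing in for `MIN_CACHE_BYTES` and `CACHE_SHRINK`, which this diff does not show.

```python
from typing import Any, Dict, Optional

MIN_CACHE_BYTES = 256 * 1024**2  # assumed floor, for illustration only
CACHE_SHRINK = 0.5               # assumed shrink factor, for illustration only


def plan_lower_cache(old_bytes: int, cache_used: int) -> Optional[Dict[str, Any]]:
    """Decide what a cache-lowering step would do, without touching an engine."""
    # Shrink, then clamp up to the floor so the budget never goes below it.
    new_bytes = max(MIN_CACHE_BYTES, int(old_bytes * CACHE_SHRINK))

    budget_can_shrink = new_bytes < old_bytes  # False at or below the floor
    evict_can_help = cache_used > 0            # nothing to free if cache is empty

    if not budget_can_shrink and not evict_can_help:
        return None  # knob exhausted: no room to shrink, nothing to evict

    return {
        "action": "lower_cache",
        "old_max_cache_bytes": old_bytes,
        # The budget only moves when a strict decrease is possible.
        "new_max_cache_bytes": new_bytes if budget_can_shrink else old_bytes,
        "evicted_bytes": cache_used if evict_can_help else 0,
    }
```

Under these assumed constants, a budget already at the floor with a populated cache yields an evict-only step (budget unchanged), while the same budget with an empty cache yields `None`.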
@@ -492,16 +515,17 @@ def attempt(
     ) -> Any:
         """Call ``fn(*args, **kwargs)`` with OOM-aware retry.
 
-        On every OOM we run ``_diagnose(err)`` to pick one of four tiers,
-        then ``_apply_diagnosis(...)`` to execute it. Tier 1 (fragmentation)
-        is allowed at most once per request — repeated "fragmentation"
-        diagnoses would mask a real residency/cache OOM. Tiers 2-4 each
-        shrink a knob (or skip when both are exhausted).
+        On every OOM we run ``_diagnose(err)`` to pick a tier, then
+        ``_apply_diagnosis(...)`` to execute it. Tier 1 (fragmentation) is
+        allowed at most once per request — repeated "fragmentation"
+        diagnoses would mask a real OOM. Non-``mixed`` tiers each touch a
+        single specific knob (no cross-knob fallback); ``mixed`` tries
+        candidates filtered by "can plausibly help" in headroom order.
 
         Loop terminates when:
           * ``fn`` returns normally
-          * Both knobs hit floor AND fragmentation has been used (no remaining
-            recovery action)
+          * The diagnosed knob is at its floor (or, in ``mixed``, every
+            plausibly-helpful knob is exhausted) — re-raise.
           * ``can_retry()`` returns False (mid-stream after partial output) —
             the knob still gets adjusted for the next request, then re-raise.
           * The safety cap is reached.
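
The termination rules in this docstring can be sketched as a small standalone loop. This is an illustration under assumptions, not the body of `attempt`: `attempt_with_recovery` and its callable parameters are hypothetical, `MemoryError` stands in for a CUDA OOM, a diagnosis is modelled as a plain dict with a `"tier"` key, and the cap of 8 attempts is an invented value.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


def attempt_with_recovery(
    fn: Callable[[], Any],
    diagnose: Callable[..., Dict[str, Any]],
    apply_diagnosis: Callable[[Dict[str, Any], int], Optional[Dict[str, Any]]],
    can_retry: Callable[[], bool] = lambda: True,
    max_attempts: int = 8,  # assumed safety cap
) -> Tuple[Any, List[Dict[str, Any]]]:
    """Retry fn, adjusting one knob per OOM; return (result, applied changes)."""
    oom_events: List[Dict[str, Any]] = []
    fragmentation_used = False
    last_err: Optional[BaseException] = None

    for attempt_idx in range(max_attempts):
        try:
            return fn(), oom_events          # fn returned normally
        except MemoryError as err:           # stand-in for a CUDA OOM
            last_err = err

        # Fresh diagnosis on every OOM; fragmentation is offered at most once.
        dx = diagnose(last_err, allow_fragmentation_tier=not fragmentation_used)
        if dx["tier"] == "fragmentation":
            fragmentation_used = True

        change = apply_diagnosis(dx, attempt_idx)
        if change is None:
            break                            # diagnosed knob exhausted
        oom_events.append(change)

        # The knob was already adjusted for the next request; stop if the
        # current request cannot be replayed (e.g. partial output was sent).
        if not can_retry():
            break

    if last_err is None:
        raise ValueError("max_attempts must be at least 1")
    raise last_err
```

Checking `can_retry()` after the knob change, rather than before, mirrors the docstring's note that the adjustment still happens even when the current request has to fail.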
@@ -539,15 +563,8 @@ def attempt(
 
             change = self._apply_diagnosis(diagnosis, attempt_idx)
             if change is None:
-                # All recovery paths exhausted (knobs at floor, fragmentation already used).
-                log.error(
-                    "OOM recovery cannot reduce further: cache=%.2f GB (floor=%.2f GB), "
-                    "keep_first_k=%s keep_last_k=%s (floor=%d), fragmentation_used=%s. "
-                    "Re-raising.",
-                    self.max_cache_bytes / 1e9, self.MIN_CACHE_BYTES / 1e9,
-                    self.keep_first_k, self.keep_last_k, self.MIN_KEEP,
-                    fragmentation_used,
-                )
+                # Diagnosed knob (or every plausibly-helpful knob in `mixed`)
+                # is at floor. _apply_diagnosis already logged the details.
                 break
             oom_events.append(change)
 
@@ -573,6 +590,41 @@ def attempt(
         assert last_err is not None
         raise last_err
 
+    def _mixed_candidates(self, dx: Diagnosis) -> List[str]:
+        """For the ``mixed`` tier, return knob names ordered by headroom,
+        keeping only knobs that can plausibly help this OOM.
+
+        Filter rules:
+          * ``lower_cache`` — kept only when eviction would free at least the
+            failed allocation. ``cm.clear()`` drops at most ``cache_used_mb``;
+            if that's already less than ``tried_alloc_mb``, the OOM persists
+            after the clear, so the knob is provably insufficient. When
+            ``tried`` is unknown, fall back to a 50 MiB absolute threshold.
+          * ``lower_prefill_chunk`` — kept whenever there's any room to shrink
+            (``prefill_headroom > 0``).
+          * ``lower_streaming`` — kept whenever streaming is enabled and has
+            residency headroom (``stream_headroom > 0``).
+
+        The result is sorted by headroom descending; ties break in the order
+        cache → prefill → streaming (cheapest first).
+        """
+        tried = dx.parsed_oom.get("tried_alloc_mb")
+        cache_can_help = (
+            dx.cache_used_mb >= tried if tried is not None
+            else dx.cache_used_mb >= 50.0
+        )
+
+        candidates: List[tuple] = []
+        if cache_can_help:
+            candidates.append(("lower_cache", dx.cache_headroom_frac, 0))
+        if dx.prefill_headroom_frac > 0:
+            candidates.append(("lower_prefill_chunk", dx.prefill_headroom_frac, 1))
+        if dx.stream_headroom_frac > 0:
+            candidates.append(("lower_streaming", dx.stream_headroom_frac, 2))
+
+        candidates.sort(key=lambda kv: (-kv[1], kv[2]))
+        return [name for name, _, _ in candidates]
+
     # ── Apply a Diagnosis ──
     def _apply_diagnosis(
         self, dx: Diagnosis, attempt_idx: int,
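
A worked example may make the filter-then-sort behaviour of `_mixed_candidates` easier to follow. The function below is a free-standing restatement that takes plain numbers instead of a `Diagnosis` object; its name, its parameter list, and the sample figures are illustrative assumptions, not values from the module.

```python
from typing import List, Optional, Tuple


def mixed_candidates(
    cache_used_mb: float,
    tried_alloc_mb: Optional[float],
    cache_headroom: float,
    prefill_headroom: float,
    stream_headroom: float,
) -> List[str]:
    """Return knob names that can plausibly help, largest headroom first."""
    # Eviction frees at most cache_used_mb, so it must cover the failed
    # allocation; with no parsed size, use the 50 MiB absolute threshold.
    threshold = tried_alloc_mb if tried_alloc_mb is not None else 50.0

    candidates: List[Tuple[str, float, int]] = []
    if cache_used_mb >= threshold:
        candidates.append(("lower_cache", cache_headroom, 0))
    if prefill_headroom > 0:
        candidates.append(("lower_prefill_chunk", prefill_headroom, 1))
    if stream_headroom > 0:
        candidates.append(("lower_streaming", stream_headroom, 2))

    # Headroom descending; the integer rank breaks ties as
    # cache, then prefill, then streaming.
    candidates.sort(key=lambda c: (-c[1], c[2]))
    return [name for name, _, _ in candidates]


# 30 MiB cached cannot cover a 512 MiB allocation, so the cache knob is
# dropped even though it has the most headroom.
small_cache = mixed_candidates(30.0, 512.0, 0.9, 0.4, 0.4)
# -> ["lower_prefill_chunk", "lower_streaming"]

# With 800 MiB cached the cache knob qualifies and sorts first.
big_cache = mixed_candidates(800.0, 512.0, 0.9, 0.4, 0.4)
# -> ["lower_cache", "lower_prefill_chunk", "lower_streaming"]
```

In the first call, prefill and streaming have equal headroom, so the fixed rank decides their order.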
@@ -617,45 +669,53 @@ def _apply_diagnosis(
             return event
 
         # ── Tiers 2-4: knob change ──
-        # For cache-action tiers, _lower_cache calls cm.clear() itself, so we
-        # don't pre-emptively reset_cache (that would lose chunks before the
-        # knob even applies). For residency / prefill tiers, we keep the cache
-        # intact — losing it doesn't help and gives up valuable reuse.
+        # Each non-`mixed` tier is authoritative: it identified the actual
+        # bottleneck, so we touch *only* that knob. Falling back to a
+        # different knob would damage unrelated state (e.g. dropping KV
+        # reuse when the real problem was prompt length) without addressing
+        # the OOM cause. If the diagnosed knob is exhausted, recovery is
+        # genuinely out of moves for this tier — re-raise.
+        #
+        # `mixed` is the only tier with real ambiguity, so it's the only
+        # tier that tries multiple knobs. Even there we filter candidates
+        # to those that can *plausibly* help — a cache-eviction step that
+        # would free fewer bytes than the failed allocation tried is
+        # provably useless and gets skipped.
         knobs = {
             "lower_cache": self._lower_cache,
             "lower_streaming": self._lower_streaming,
             "lower_prefill_chunk": self._lower_prefill_chunk,
         }
-        primary_name = dx.action
-        # Fallback order: try the diagnosed knob first, then the others in a
-        # fixed order (cache → prefill → streaming). Prefill before streaming
-        # because most cache-empty OOMs are prefill activations, not weights.
-        fallback_chain = [primary_name] + [
-            k for k in ("lower_cache", "lower_prefill_chunk", "lower_streaming")
-            if k != primary_name
-        ]
+
+        if dx.tier == "mixed":
+            ordered = self._mixed_candidates(dx)
+        else:
+            ordered = [dx.action]
 
         change = None
         flipped_to = None
-        for i, knob_name in enumerate(fallback_chain):
+        for i, knob_name in enumerate(ordered):
             change = knobs[knob_name]()
             if change is not None:
                 if i > 0:
                     flipped_to = knob_name
                     log.warning(
-                        "OOM recovery: diagnosed action '%s' unavailable; "
-                        "falling back to '%s'.",
-                        primary_name, knob_name,
+                        "OOM recovery (mixed): primary '%s' unavailable; "
+                        "using '%s' instead.",
+                        dx.action, knob_name,
                     )
                 break
 
         if change is None:
             log.error(
-                "OOM recovery EXHAUSTED at attempt %d (tier=%s): "
-                "cache=%.2f GB (floor=%.2f GB), keep=%s/%s (floor=%d), "
-                "prefill_chunk=%d (floor=%d). Re-raising original CUDA OOM.",
-                attempt_idx, dx.tier,
-                self.max_cache_bytes / 1e9, self.MIN_CACHE_BYTES / 1e9,
+                "OOM recovery EXHAUSTED at attempt %d (tier=%s, action=%s): "
+                "cache_used=%.0f MiB cache_budget=%.0f MiB (floor=%.0f MiB), "
+                "keep=%s/%s (floor=%d), prefill_chunk=%d (floor=%d). "
+                "The diagnosed knob is at its floor and no fallback could help. "
+                "Re-raising original CUDA OOM.",
+                attempt_idx, dx.tier, dx.action,
+                dx.cache_used_mb, dx.cache_budget_mb,
+                self.MIN_CACHE_BYTES / (1024.0 ** 2),
                 self.keep_first_k, self.keep_last_k, self.MIN_KEEP,
                 self.prefill_chunk_size, self.MIN_PREFILL_CHUNK,
             )
@@ -678,11 +738,20 @@ def _apply_diagnosis(
 
         action = change["action"]
         if action == "lower_cache":
-            summary = (
-                f"lower_cache: max_cache_bytes "
-                f"{change['old_max_cache_bytes'] / 1e9:.2f} GB → "
-                f"{change['new_max_cache_bytes'] / 1e9:.2f} GB"
-            )
+            old_gb = change["old_max_cache_bytes"] / 1e9
+            new_gb = change["new_max_cache_bytes"] / 1e9
+            evicted_mb = change.get("evicted_bytes", 0) / (1024.0 ** 2)
+            if old_gb == new_gb:
+                # Budget at floor — pure eviction step.
+                summary = (
+                    f"lower_cache: evict-only, freed {evicted_mb:.0f} MiB "
+                    f"(budget pinned at {new_gb:.2f} GB floor)"
+                )
+            else:
+                summary = (
+                    f"lower_cache: max_cache_bytes {old_gb:.2f} GB → "
+                    f"{new_gb:.2f} GB (freed {evicted_mb:.0f} MiB)"
+                )
             shrink = self.CACHE_SHRINK
         elif action == "lower_streaming":
             summary = (