Skip to content

Commit 5d508ea

Browse files
committed
enforce new OOM policy
1 parent 09186f8 commit 5d508ea

1 file changed

Lines changed: 132 additions & 63 deletions

File tree

src/kvboost/oom_recovery.py

Lines changed: 132 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,16 @@
1313
on the engine. Cuts peak prefill activation memory (the real culprit
1414
for long prompts, where neither knob above can help).
1515
16-
A repeated OOM with the *same* knob just adjusted flips to the next knob,
17-
so a single persistent bottleneck still gets fully addressed. Each knob has
18-
a floor (``MIN_CACHE_BYTES`` / ``MIN_KEEP`` / ``MIN_PREFILL_CHUNK``); when
19-
all are exhausted the original exception re-raises.
16+
Each tier is authoritative: cache-dominant OOMs touch *only* the cache,
17+
prefill-bound OOMs touch *only* prefill_chunk_size, residency-bound OOMs
18+
touch *only* streaming residency. Falling back to a different knob would
19+
damage unrelated state (e.g. dropping KV reuse when the real problem was
20+
prompt length) without addressing the cause. The only tier with cross-knob
21+
fallback is ``mixed`` — by definition genuinely ambiguous — and even there
22+
candidates are filtered to those that can plausibly help. A repeated OOM
23+
gets a fresh diagnosis, so a shifting bottleneck still ends up at the
24+
right knob across attempts. When the diagnosed knob is at its floor, the
25+
original exception re-raises.
2026
2127
This module is used by:
2228
* the kvboost inference server (``kvboost.server.engine_worker``)
@@ -376,29 +382,46 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
376382
# decrease is possible (clamped at floor, or initial value was already
377383
# below the floor) is the signal to the caller that this knob is exhausted.
378384
def _lower_cache(self) -> Optional[Dict[str, Any]]:
385+
"""Free KV cache memory. The freeing happens via ``cm.clear()``;
386+
lowering the budget is just a prophylactic so it doesn't refill to
387+
the same level on the next request.
388+
389+
Returns None only when this knob can't help — i.e. the budget is at
390+
floor AND the cache is empty (nothing to clear, no room to shrink).
391+
A ``cm.clear()`` failure is swallowed. When the budget is at floor but
392+
the cache is still populated, we still evict (budget unchanged).
393+
"""
394+
cm = self.engine.cache_manager
379395
old_bytes = self.max_cache_bytes
380-
# Apply shrink, then clamp UP to the floor so we never go below it,
381-
# then require strict decrease relative to old (which prevents the
382-
# floor clamp from ever raising the value).
383396
new_bytes = max(self.MIN_CACHE_BYTES, int(old_bytes * self.CACHE_SHRINK))
384-
if new_bytes >= old_bytes:
397+
cache_used = self._cache_bytes_used()
398+
399+
budget_can_shrink = new_bytes < old_bytes
400+
evict_can_help = cache_used > 0
401+
402+
if not budget_can_shrink and not evict_can_help:
385403
return None
386-
cm = self.engine.cache_manager
387-
for attr in ("max_cache_bytes", "_max_cache_bytes", "max_bytes"):
388-
if hasattr(cm, attr):
389-
try:
390-
setattr(cm, attr, new_bytes)
391-
except Exception:
392-
pass
393-
self.max_cache_bytes = new_bytes
394-
try:
395-
cm.clear()
396-
except Exception:
397-
pass
404+
405+
if budget_can_shrink:
406+
for attr in ("max_cache_bytes", "_max_cache_bytes", "max_bytes"):
407+
if hasattr(cm, attr):
408+
try:
409+
setattr(cm, attr, new_bytes)
410+
except Exception:
411+
pass
412+
self.max_cache_bytes = new_bytes
413+
414+
if evict_can_help:
415+
try:
416+
cm.clear()
417+
except Exception:
418+
pass
419+
398420
return {
399421
"action": "lower_cache",
400422
"old_max_cache_bytes": old_bytes,
401-
"new_max_cache_bytes": new_bytes,
423+
"new_max_cache_bytes": self.max_cache_bytes,
424+
"evicted_bytes": cache_used if evict_can_help else 0,
402425
}
403426

404427
def _lower_streaming(self) -> Optional[Dict[str, Any]]:
@@ -492,16 +515,17 @@ def attempt(
492515
) -> Any:
493516
"""Call ``fn(*args, **kwargs)`` with OOM-aware retry.
494517
495-
On every OOM we run ``_diagnose(err)`` to pick one of four tiers,
496-
then ``_apply_diagnosis(...)`` to execute it. Tier 1 (fragmentation)
497-
is allowed at most once per request — repeated "fragmentation"
498-
diagnoses would mask a real residency/cache OOM. Tiers 2-4 each
499-
shrink a knob (or skip when both are exhausted).
518+
On every OOM we run ``_diagnose(err)`` to pick a tier, then
519+
``_apply_diagnosis(...)`` to execute it. Tier 1 (fragmentation) is
520+
allowed at most once per request — repeated "fragmentation"
521+
diagnoses would mask a real OOM. Non-``mixed`` tiers each touch a
522+
single specific knob (no cross-knob fallback); ``mixed`` tries
523+
candidates filtered by "can plausibly help" in headroom order.
500524
501525
Loop terminates when:
502526
* ``fn`` returns normally
503-
* Both knobs hit floor AND fragmentation has been used (no remaining
504-
recovery action)
527+
* The diagnosed knob is at its floor (or, in ``mixed``, every
528+
plausibly-helpful knob is exhausted) — re-raise.
505529
* ``can_retry()`` returns False (mid-stream after partial output) —
506530
the knob still gets adjusted for the next request, then re-raise.
507531
* The safety cap is reached.
@@ -539,15 +563,8 @@ def attempt(
539563

540564
change = self._apply_diagnosis(diagnosis, attempt_idx)
541565
if change is None:
542-
# All recovery paths exhausted (knobs at floor, fragmentation already used).
543-
log.error(
544-
"OOM recovery cannot reduce further: cache=%.2f GB (floor=%.2f GB), "
545-
"keep_first_k=%s keep_last_k=%s (floor=%d), fragmentation_used=%s. "
546-
"Re-raising.",
547-
self.max_cache_bytes / 1e9, self.MIN_CACHE_BYTES / 1e9,
548-
self.keep_first_k, self.keep_last_k, self.MIN_KEEP,
549-
fragmentation_used,
550-
)
566+
# Diagnosed knob (or every plausibly-helpful knob in `mixed`)
567+
# is at floor. _apply_diagnosis already logged the details.
551568
break
552569
oom_events.append(change)
553570

@@ -573,6 +590,41 @@ def attempt(
573590
assert last_err is not None
574591
raise last_err
575592

593+
def _mixed_candidates(self, dx: Diagnosis) -> List[str]:
594+
"""For the ``mixed`` tier, return knob names ordered by headroom,
595+
keeping only knobs that can plausibly help this OOM.
596+
597+
Filter rules:
598+
* ``lower_cache`` — kept only when eviction would free at least the
599+
failed allocation. ``cm.clear()`` drops at most ``cache_used_mb``;
600+
if that's already less than ``tried_alloc_mb``, the OOM persists
601+
after the clear, so the knob is provably insufficient. When
602+
``tried`` is unknown, fall back to a 50 MiB absolute threshold.
603+
* ``lower_prefill_chunk`` — kept whenever there's any room to shrink
604+
(``prefill_headroom > 0``).
605+
* ``lower_streaming`` — kept whenever streaming is enabled and has
606+
residency headroom (``stream_headroom > 0``).
607+
608+
The result is sorted by headroom descending; ties break in the order
609+
cache → prefill → streaming (cheapest first).
610+
"""
611+
tried = dx.parsed_oom.get("tried_alloc_mb")
612+
cache_can_help = (
613+
dx.cache_used_mb >= tried if tried is not None
614+
else dx.cache_used_mb >= 50.0
615+
)
616+
617+
candidates: List[tuple] = []
618+
if cache_can_help:
619+
candidates.append(("lower_cache", dx.cache_headroom_frac, 0))
620+
if dx.prefill_headroom_frac > 0:
621+
candidates.append(("lower_prefill_chunk", dx.prefill_headroom_frac, 1))
622+
if dx.stream_headroom_frac > 0:
623+
candidates.append(("lower_streaming", dx.stream_headroom_frac, 2))
624+
625+
candidates.sort(key=lambda kv: (-kv[1], kv[2]))
626+
return [name for name, _, _ in candidates]
627+
576628
# ── Apply a Diagnosis ──
577629
def _apply_diagnosis(
578630
self, dx: Diagnosis, attempt_idx: int,
@@ -617,45 +669,53 @@ def _apply_diagnosis(
617669
return event
618670

619671
# ── Tiers 2-4: knob change ──
620-
# For cache-action tiers, _lower_cache calls cm.clear() itself, so we
621-
# don't pre-emptively reset_cache (that would lose chunks before the
622-
# knob even applies). For residency / prefill tiers, we keep the cache
623-
# intact — losing it doesn't help and gives up valuable reuse.
672+
# Each non-`mixed` tier is authoritative: it identified the actual
673+
# bottleneck, so we touch *only* that knob. Falling back to a
674+
# different knob would damage unrelated state (e.g. dropping KV
675+
# reuse when the real problem was prompt length) without addressing
676+
# the OOM cause. If the diagnosed knob is exhausted, recovery is
677+
# genuinely out of moves for this tier — re-raise.
678+
#
679+
# `mixed` is the only tier with real ambiguity, so it's the only
680+
# tier that tries multiple knobs. Even there we filter candidates
681+
# to those that can *plausibly* help — a cache-eviction step that
682+
# would free fewer bytes than the failed allocation tried is
683+
# provably useless and gets skipped.
624684
knobs = {
625685
"lower_cache": self._lower_cache,
626686
"lower_streaming": self._lower_streaming,
627687
"lower_prefill_chunk": self._lower_prefill_chunk,
628688
}
629-
primary_name = dx.action
630-
# Fallback order: try the diagnosed knob first, then the others in a
631-
# fixed order (cache → prefill → streaming). Prefill before streaming
632-
# because most cache-empty OOMs are prefill activations, not weights.
633-
fallback_chain = [primary_name] + [
634-
k for k in ("lower_cache", "lower_prefill_chunk", "lower_streaming")
635-
if k != primary_name
636-
]
689+
690+
if dx.tier == "mixed":
691+
ordered = self._mixed_candidates(dx)
692+
else:
693+
ordered = [dx.action]
637694

638695
change = None
639696
flipped_to = None
640-
for i, knob_name in enumerate(fallback_chain):
697+
for i, knob_name in enumerate(ordered):
641698
change = knobs[knob_name]()
642699
if change is not None:
643700
if i > 0:
644701
flipped_to = knob_name
645702
log.warning(
646-
"OOM recovery: diagnosed action '%s' unavailable; "
647-
"falling back to '%s'.",
648-
primary_name, knob_name,
703+
"OOM recovery (mixed): primary '%s' unavailable; "
704+
"using '%s' instead.",
705+
dx.action, knob_name,
649706
)
650707
break
651708

652709
if change is None:
653710
log.error(
654-
"OOM recovery EXHAUSTED at attempt %d (tier=%s): "
655-
"cache=%.2f GB (floor=%.2f GB), keep=%s/%s (floor=%d), "
656-
"prefill_chunk=%d (floor=%d). Re-raising original CUDA OOM.",
657-
attempt_idx, dx.tier,
658-
self.max_cache_bytes / 1e9, self.MIN_CACHE_BYTES / 1e9,
711+
"OOM recovery EXHAUSTED at attempt %d (tier=%s, action=%s): "
712+
"cache_used=%.0f MiB cache_budget=%.0f MiB (floor=%.0f MiB), "
713+
"keep=%s/%s (floor=%d), prefill_chunk=%d (floor=%d). "
714+
"The diagnosed knob is at its floor and no fallback could help. "
715+
"Re-raising original CUDA OOM.",
716+
attempt_idx, dx.tier, dx.action,
717+
dx.cache_used_mb, dx.cache_budget_mb,
718+
self.MIN_CACHE_BYTES / (1024.0 ** 2),
659719
self.keep_first_k, self.keep_last_k, self.MIN_KEEP,
660720
self.prefill_chunk_size, self.MIN_PREFILL_CHUNK,
661721
)
@@ -678,11 +738,20 @@ def _apply_diagnosis(
678738

679739
action = change["action"]
680740
if action == "lower_cache":
681-
summary = (
682-
f"lower_cache: max_cache_bytes "
683-
f"{change['old_max_cache_bytes'] / 1e9:.2f} GB → "
684-
f"{change['new_max_cache_bytes'] / 1e9:.2f} GB"
685-
)
741+
old_gb = change["old_max_cache_bytes"] / 1e9
742+
new_gb = change["new_max_cache_bytes"] / 1e9
743+
evicted_mb = change.get("evicted_bytes", 0) / (1024.0 ** 2)
744+
if old_gb == new_gb:
745+
# Budget at floor — pure eviction step.
746+
summary = (
747+
f"lower_cache: evict-only, freed {evicted_mb:.0f} MiB "
748+
f"(budget pinned at {new_gb:.2f} GB floor)"
749+
)
750+
else:
751+
summary = (
752+
f"lower_cache: max_cache_bytes {old_gb:.2f} GB → "
753+
f"{new_gb:.2f} GB (freed {evicted_mb:.0f} MiB)"
754+
)
686755
shrink = self.CACHE_SHRINK
687756
elif action == "lower_streaming":
688757
summary = (

0 commit comments

Comments
 (0)