Skip to content

Commit 09186f8

Browse files
committed
OOM Recovery: reduce prefill chunk size
1 parent 31be9a7 commit 09186f8

3 files changed

Lines changed: 163 additions & 47 deletions

File tree

benchmarks_and_experiments/sharegpt_3way/run_kvboost.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def main():
273273
initial_keep_first_k=args.keep_first_k if args.awq_streaming else None,
274274
initial_keep_last_k=args.keep_last_k if args.awq_streaming else None,
275275
streaming_enabled=args.awq_streaming,
276+
initial_prefill_chunk_size=getattr(args, "prefill_chunk_size", 0),
276277
max_retries=args.oom_max_retries,
277278
)
278279

src/kvboost/oom_recovery.py

Lines changed: 161 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
* KV cache "low" → lower streaming residency (``keep_first_k`` /
1010
``keep_last_k``). Layers drop out of VRAM, costing streaming overhead
1111
but unblocking the request.
12+
* Cache nearly empty AND streaming disabled → lower ``prefill_chunk_size``
13+
on the engine. Cuts peak prefill activation memory (the real culprit
14+
for long prompts, where neither knob above can help).
1215
13-
A repeated OOM with the *same* knob just adjusted flips to the other knob,
16+
A repeated OOM with the *same* knob just adjusted flips to the next knob,
1417
so a single persistent bottleneck still gets fully addressed. Each knob has
15-
a floor (``MIN_CACHE_BYTES`` / ``MIN_KEEP``); when both are exhausted the
16-
original exception re-raises.
18+
a floor (``MIN_CACHE_BYTES`` / ``MIN_KEEP`` / ``MIN_PREFILL_CHUNK``); when
19+
all are exhausted the original exception re-raises.
1720
1821
This module is used by:
1922
* the kvboost inference server (``kvboost.server.engine_worker``)
@@ -45,20 +48,24 @@ class Diagnosis:
4548
once per request to avoid masking real OOMs.
4649
* ``"cache_dominant"`` — cache holds ≥ 1.5× tried-alloc; eviction will
4750
free more than the failed allocation needs.
51+
* ``"prefill_bound"`` — cache holds < 0.5× tried-alloc AND streaming
52+
can't help (disabled or exhausted); the OOM is prefill activation
53+
memory, not KV cache. Lower ``prefill_chunk_size`` on the engine.
4854
* ``"residency_bound"`` — cache holds < 0.5× tried-alloc; cache can't
4955
plausibly close the gap, must lower streaming residency.
50-
* ``"mixed"`` — neither knob is the obvious culprit; pick the one with
56+
* ``"mixed"`` — no knob is the obvious culprit; pick the one with
5157
more remaining headroom (% from floor), cache wins ties (cheaper).
5258
"""
5359

5460
tier: str
55-
action: str # "empty_cache_only" | "lower_cache" | "lower_streaming"
61+
action: str # "empty_cache_only" | "lower_cache" | "lower_streaming" | "lower_prefill_chunk"
5662
reason: str
5763
parsed_oom: Dict[str, Optional[float]] = field(default_factory=dict)
5864
cache_used_mb: float = 0.0
5965
cache_budget_mb: float = 0.0
6066
cache_headroom_frac: float = 0.0
6167
stream_headroom_frac: float = 0.0
68+
prefill_headroom_frac: float = 0.0
6269

6370

6471
# Regexes for torch's CUDA OOM message format. Defensive: any field that
@@ -126,8 +133,13 @@ class OOMRecovery:
126133

127134
CACHE_SHRINK = 0.7
128135
STREAM_SHRINK = 0.5
136+
PREFILL_SHRINK = 0.5
129137
MIN_CACHE_BYTES = int(2.5e8) # 250 MB floor
130138
MIN_KEEP = 1 # absolute floor — at least 1 resident layer each side
139+
MIN_PREFILL_CHUNK = 32 # absolute floor for prefill chunk size
140+
# Starting value when the engine was launched with prefill_chunk_size=0
141+
# (single-shot prefill) and we need to enable chunking on first OOM.
142+
INITIAL_PREFILL_CHUNK_ON_OOM = 2048
131143
SAFETY_CAP = 16 # absolute attempt cap to avoid pathological loops
132144

133145
def __init__(
@@ -138,6 +150,7 @@ def __init__(
138150
initial_keep_first_k: Optional[int],
139151
initial_keep_last_k: Optional[int],
140152
streaming_enabled: bool,
153+
initial_prefill_chunk_size: int = 0,
141154
max_retries: Optional[int] = None,
142155
):
143156
"""Initialise recovery state.
@@ -152,6 +165,7 @@ def __init__(
152165
self.keep_first_k = initial_keep_first_k
153166
self.keep_last_k = initial_keep_last_k
154167
self.streaming_enabled = streaming_enabled
168+
self.prefill_chunk_size = int(initial_prefill_chunk_size)
155169
self.max_retries = max_retries
156170
self.events: List[Dict[str, Any]] = []
157171
# Lifetime counter (across all calls) for fragmentation diagnoses;
@@ -195,6 +209,16 @@ def _stream_headroom_frac(self) -> float:
195209
return 0.0
196210
return (self.keep_first_k - self.MIN_KEEP) / self.keep_first_k
197211

212+
def _prefill_headroom_frac(self) -> float:
213+
"""How much room is left to shrink prefill chunk size. A current value
214+
of 0 (single-shot prefill) counts as full headroom — we can switch to
215+
``INITIAL_PREFILL_CHUNK_ON_OOM`` on first OOM."""
216+
if self.prefill_chunk_size == 0:
217+
return 1.0
218+
if self.prefill_chunk_size <= self.MIN_PREFILL_CHUNK:
219+
return 0.0
220+
return (self.prefill_chunk_size - self.MIN_PREFILL_CHUNK) / self.prefill_chunk_size
221+
198222
# ── 4-tier diagnosis ──
199223
def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Diagnosis:
200224
"""Classify an OOM into one of four tiers and pick the cheapest action.
@@ -208,6 +232,7 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
208232
cache_budget_mb = self.max_cache_bytes / (1024.0 ** 2)
209233
cache_h = self._cache_headroom_frac()
210234
stream_h = self._stream_headroom_frac()
235+
prefill_h = self._prefill_headroom_frac()
211236

212237
tried = parsed["tried_alloc_mb"]
213238
frag = parsed["reserved_unalloc_mb"]
@@ -230,25 +255,39 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
230255
cache_budget_mb=cache_budget_mb,
231256
cache_headroom_frac=cache_h,
232257
stream_headroom_frac=stream_h,
258+
prefill_headroom_frac=prefill_h,
233259
)
234260

235261
# ── No tried_alloc → fall back to a simple budget-fraction heuristic ──
236-
# Old behaviour: if cache > HIGH_FRAC × budget, lower cache; else streaming.
262+
# Cache-empty + parseless OOM = almost always prefill activation.
263+
# Pick prefill_chunk first if it has headroom, then fall back to the
264+
# legacy cache-fraction heuristic.
237265
if tried is None:
238-
cache_dominant = cache_used_mb > 0.5 * cache_budget_mb
239-
action = "lower_cache" if cache_dominant or not self.streaming_enabled else "lower_streaming"
266+
cache_nearly_empty = cache_used_mb < 0.1 * max(cache_budget_mb, 1.0)
267+
if cache_nearly_empty and prefill_h > 0:
268+
action = "lower_prefill_chunk"
269+
reason = (
270+
f"OOM message unparseable; cache is nearly empty "
271+
f"({cache_used_mb:.0f}/{cache_budget_mb:.0f} MiB) "
272+
f"→ lowering prefill chunk"
273+
)
274+
else:
275+
cache_dominant = cache_used_mb > 0.5 * cache_budget_mb
276+
action = "lower_cache" if cache_dominant or not self.streaming_enabled else "lower_streaming"
277+
reason = (
278+
f"OOM message unparseable; legacy heuristic: cache_used "
279+
f"{cache_used_mb:.0f}/{cache_budget_mb:.0f} MiB → {action}"
280+
)
240281
return Diagnosis(
241282
tier="mixed",
242283
action=action,
243-
reason=(
244-
f"OOM message unparseable; legacy heuristic: cache_used "
245-
f"{cache_used_mb:.0f}/{cache_budget_mb:.0f} MiB → {action}"
246-
),
284+
reason=reason,
247285
parsed_oom=parsed,
248286
cache_used_mb=cache_used_mb,
249287
cache_budget_mb=cache_budget_mb,
250288
cache_headroom_frac=cache_h,
251289
stream_headroom_frac=stream_h,
290+
prefill_headroom_frac=prefill_h,
252291
)
253292

254293
# ── Tier 2: Cache dominant ──
@@ -266,9 +305,33 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
266305
cache_budget_mb=cache_budget_mb,
267306
cache_headroom_frac=cache_h,
268307
stream_headroom_frac=stream_h,
308+
prefill_headroom_frac=prefill_h,
309+
)
310+
311+
# ── Tier 3a: Prefill bound (cache can't help; streaming can't either) ──
312+
# cache_used << tried and either streaming is disabled or already at
313+
# its floor → the OOM is almost certainly prefill activation memory.
314+
# Lower prefill_chunk_size on the engine to cap peak per-step memory.
315+
if (cache_used_mb < self.TIER_RESIDENCY_BOUND * tried
316+
and prefill_h > 0
317+
and (not self.streaming_enabled or stream_h <= 0.0)):
318+
return Diagnosis(
319+
tier="prefill_bound",
320+
action="lower_prefill_chunk",
321+
reason=(
322+
f"cache_used {cache_used_mb:.0f} MiB << tried-alloc "
323+
f"{tried:.0f} MiB and streaming can't help — OOM is "
324+
f"prefill activation, lowering chunk size"
325+
),
326+
parsed_oom=parsed,
327+
cache_used_mb=cache_used_mb,
328+
cache_budget_mb=cache_budget_mb,
329+
cache_headroom_frac=cache_h,
330+
stream_headroom_frac=stream_h,
331+
prefill_headroom_frac=prefill_h,
269332
)
270333

271-
# ── Tier 3: Residency bound ──
334+
# ── Tier 3b: Residency bound ──
272335
if cache_used_mb < self.TIER_RESIDENCY_BOUND * tried and self.streaming_enabled:
273336
return Diagnosis(
274337
tier="residency_bound",
@@ -283,25 +346,29 @@ def _diagnose(self, err: BaseException, *, allow_fragmentation_tier: bool) -> Di
283346
cache_budget_mb=cache_budget_mb,
284347
cache_headroom_frac=cache_h,
285348
stream_headroom_frac=stream_h,
349+
prefill_headroom_frac=prefill_h,
286350
)
287351

288352
# ── Tier 4: Mixed ──
289-
# Cache wins ties because eviction is cheaper (and stream headroom is
290-
# -1.0 when streaming is off, so cache also wins by default there).
291-
action = "lower_streaming" if stream_h > cache_h else "lower_cache"
353+
# Highest remaining headroom wins; cache breaks ties (cheapest).
354+
candidates = [("lower_cache", cache_h), ("lower_streaming", stream_h),
355+
("lower_prefill_chunk", prefill_h)]
356+
action = max(candidates, key=lambda kv: kv[1])[0]
292357
return Diagnosis(
293358
tier="mixed",
294359
action=action,
295360
reason=(
296361
f"ambiguous (cache_used {cache_used_mb:.0f} MiB ≈ tried "
297362
f"{tried:.0f} MiB): cache_headroom={cache_h:.0%}, "
298-
f"stream_headroom={stream_h:.0%}{action}"
363+
f"stream_headroom={stream_h:.0%}, "
364+
f"prefill_headroom={prefill_h:.0%}{action}"
299365
),
300366
parsed_oom=parsed,
301367
cache_used_mb=cache_used_mb,
302368
cache_budget_mb=cache_budget_mb,
303369
cache_headroom_frac=cache_h,
304370
stream_headroom_frac=stream_h,
371+
prefill_headroom_frac=prefill_h,
305372
)
306373

307374
# ── Knob adjusters ──
@@ -390,6 +457,31 @@ def _lower_streaming(self) -> Optional[Dict[str, Any]]:
390457
"keep_last_k": new_last,
391458
}
392459

460+
def _lower_prefill_chunk(self) -> Optional[Dict[str, Any]]:
461+
"""Reduce engine.prefill_chunk_size to cap peak prefill activation
462+
memory. From 0 (single-shot) we jump to ``INITIAL_PREFILL_CHUNK_ON_OOM``;
463+
otherwise we halve down to ``MIN_PREFILL_CHUNK``. Returns None when the
464+
engine doesn't expose the knob or we've hit the floor."""
465+
if not hasattr(self.engine, "prefill_chunk_size"):
466+
return None
467+
old = self.prefill_chunk_size
468+
if old == 0:
469+
new = self.INITIAL_PREFILL_CHUNK_ON_OOM
470+
else:
471+
new = max(self.MIN_PREFILL_CHUNK, int(old * self.PREFILL_SHRINK))
472+
if new >= old:
473+
return None
474+
try:
475+
self.engine.prefill_chunk_size = new
476+
except Exception:
477+
return None
478+
self.prefill_chunk_size = new
479+
return {
480+
"action": "lower_prefill_chunk",
481+
"old_prefill_chunk_size": old,
482+
"new_prefill_chunk_size": new,
483+
}
484+
393485
# ── Driver ──
394486
def attempt(
395487
self,
@@ -527,32 +619,45 @@ def _apply_diagnosis(
527619
# ── Tiers 2-4: knob change ──
528620
# For cache-action tiers, _lower_cache calls cm.clear() itself, so we
529621
# don't pre-emptively reset_cache (that would lose chunks before the
530-
# knob even applies). For residency tiers, we keep the cache intact —
531-
# losing it doesn't help residency pressure and gives up valuable reuse.
622+
# knob even applies). For residency / prefill tiers, we keep the cache
623+
# intact — losing it doesn't help and gives up valuable reuse.
624+
knobs = {
625+
"lower_cache": self._lower_cache,
626+
"lower_streaming": self._lower_streaming,
627+
"lower_prefill_chunk": self._lower_prefill_chunk,
628+
}
532629
primary_name = dx.action
533-
secondary_name = "lower_streaming" if primary_name == "lower_cache" else "lower_cache"
534-
primary = self._lower_cache if primary_name == "lower_cache" else self._lower_streaming
535-
secondary = self._lower_streaming if primary_name == "lower_cache" else self._lower_cache
536-
537-
change = primary()
538-
flipped = False
539-
if change is None:
540-
log.warning(
541-
"OOM recovery: diagnosed action '%s' unavailable (floor reached or disabled); "
542-
"falling back to '%s'.",
543-
primary_name, secondary_name,
544-
)
545-
change = secondary()
546-
flipped = True
630+
# Fallback order: try the diagnosed knob first, then the others in a
631+
# fixed order (cache → prefill → streaming). Prefill before streaming
632+
# because most cache-empty OOMs are prefill activations, not weights.
633+
fallback_chain = [primary_name] + [
634+
k for k in ("lower_cache", "lower_prefill_chunk", "lower_streaming")
635+
if k != primary_name
636+
]
637+
638+
change = None
639+
flipped_to = None
640+
for i, knob_name in enumerate(fallback_chain):
641+
change = knobs[knob_name]()
642+
if change is not None:
643+
if i > 0:
644+
flipped_to = knob_name
645+
log.warning(
646+
"OOM recovery: diagnosed action '%s' unavailable; "
647+
"falling back to '%s'.",
648+
primary_name, knob_name,
649+
)
650+
break
547651

548652
if change is None:
549653
log.error(
550654
"OOM recovery EXHAUSTED at attempt %d (tier=%s): "
551-
"cache=%.2f GB (floor=%.2f GB), keep=%s/%s (floor=%d). "
552-
"Re-raising original CUDA OOM.",
655+
"cache=%.2f GB (floor=%.2f GB), keep=%s/%s (floor=%d), "
656+
"prefill_chunk=%d (floor=%d). Re-raising original CUDA OOM.",
553657
attempt_idx, dx.tier,
554658
self.max_cache_bytes / 1e9, self.MIN_CACHE_BYTES / 1e9,
555659
self.keep_first_k, self.keep_last_k, self.MIN_KEEP,
660+
self.prefill_chunk_size, self.MIN_PREFILL_CHUNK,
556661
)
557662
return None
558663

@@ -565,31 +670,39 @@ def _apply_diagnosis(
565670
change["cache_budget_mb"] = dx.cache_budget_mb
566671
change["cache_headroom_frac"] = dx.cache_headroom_frac
567672
change["stream_headroom_frac"] = dx.stream_headroom_frac
568-
change["flipped_to_secondary"] = flipped
673+
change["prefill_headroom_frac"] = dx.prefill_headroom_frac
674+
change["flipped_to_secondary"] = flipped_to is not None
675+
if flipped_to is not None:
676+
change["flipped_to"] = flipped_to
569677
self.events.append(change)
570678

571-
if change["action"] == "lower_cache":
679+
action = change["action"]
680+
if action == "lower_cache":
572681
summary = (
573682
f"lower_cache: max_cache_bytes "
574683
f"{change['old_max_cache_bytes'] / 1e9:.2f} GB → "
575684
f"{change['new_max_cache_bytes'] / 1e9:.2f} GB"
576685
)
577-
change["summary"] = summary
578-
log.warning(
579-
"OOM recovery → %s (×%.2f, tier=%s, attempt=%d). Will retry.",
580-
summary, self.CACHE_SHRINK, dx.tier, attempt_idx,
581-
)
582-
else:
686+
shrink = self.CACHE_SHRINK
687+
elif action == "lower_streaming":
583688
summary = (
584689
f"lower_streaming: keep_first_k {change['old_keep_first_k']}→"
585690
f"{change['keep_first_k']}, keep_last_k "
586691
f"{change['old_keep_last_k'] or 0}{change['keep_last_k']}"
587692
)
588-
change["summary"] = summary
589-
log.warning(
590-
"OOM recovery → %s (×%.2f, tier=%s, attempt=%d). Will retry.",
591-
summary, self.STREAM_SHRINK, dx.tier, attempt_idx,
693+
shrink = self.STREAM_SHRINK
694+
else: # lower_prefill_chunk
695+
summary = (
696+
f"lower_prefill_chunk: prefill_chunk_size "
697+
f"{change['old_prefill_chunk_size']} → "
698+
f"{change['new_prefill_chunk_size']}"
592699
)
700+
shrink = self.PREFILL_SHRINK
701+
change["summary"] = summary
702+
log.warning(
703+
"OOM recovery → %s (×%.2f, tier=%s, attempt=%d). Will retry.",
704+
summary, shrink, dx.tier, attempt_idx,
705+
)
593706
return change
594707

595708
def snapshot(self) -> Dict[str, Any]:
@@ -600,5 +713,6 @@ def snapshot(self) -> Dict[str, Any]:
600713
"max_cache_bytes": self.max_cache_bytes,
601714
"keep_first_k": self.keep_first_k,
602715
"keep_last_k": self.keep_last_k,
716+
"prefill_chunk_size": self.prefill_chunk_size,
603717
"events": list(self.events),
604718
}

src/kvboost/server/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def main():
460460
initial_keep_first_k=args.keep_first_k if args.awq_streaming else None,
461461
initial_keep_last_k=args.keep_last_k if args.awq_streaming else None,
462462
streaming_enabled=args.awq_streaming,
463+
initial_prefill_chunk_size=args.prefill_chunk_size,
463464
max_retries=args.oom_max_retries,
464465
)
465466
log.info(

0 commit comments

Comments
 (0)