Skip to content

Commit f16bced

Browse files
committed
OOM Recovery tokens
1 parent 60cf88f commit f16bced

24 files changed

Lines changed: 3784 additions & 48 deletions

src/kvboost/engine.py

Lines changed: 128 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
import time
3131
from contextlib import contextmanager
3232
from dataclasses import dataclass, field
33-
from typing import Callable, Dict, List, Optional, Set, Tuple
33+
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple
34+
35+
if TYPE_CHECKING:
36+
from .speculative.tree.config import TreeSpeculativeConfig
3437

3538
import torch
3639
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -101,6 +104,11 @@ def __init__(
101104
prefill_chunk_size: int = 0,
102105
# Speculative decoding (None = disabled, baseline decode path)
103106
speculative_config: Optional["SpeculativeConfig"] = None,
107+
# Tree speculative — SpecBlock-inspired, may coexist with flat
108+
tree_speculative_config: Optional["TreeSpeculativeConfig"] = None,
109+
# Cost coefficients (probed at server startup) for cost-aware
110+
# tree shape + mode selection. None = degraded mode (defaults).
111+
cost_coefficients: Any = None,
104112
):
105113
if device is None:
106114
device = default_device()
@@ -160,30 +168,105 @@ def __init__(
160168

161169
# Speculative decoding (decode-phase orthogonal to recompute_strategy).
162170
# CacheBlend handles prefill; speculative handles decode. They stack.
171+
# Two flavors: flat (token-by-token K draft) and tree (SpecBlock-
172+
# inspired). Both may be present; the bridge / ModeSelector picks
173+
# per-request when so.
163174
self.speculative_config = speculative_config
175+
self.tree_speculative_config = tree_speculative_config
176+
self.cost_coefficients = cost_coefficients
164177
self.speculative_engine = None
165-
if speculative_config is not None:
166-
speculative_config.validate()
178+
self.tree_speculative_engine = None
179+
self.mode_selector = None
180+
181+
need_draft = (
182+
speculative_config is not None
183+
or tree_speculative_config is not None
184+
)
185+
if need_draft:
167186
from .speculative.draft import DraftModel
168-
from .speculative.engine import SpeculativeEngine
169187
from .speculative.stats import SpeculativeStats
170188
from .speculative.verifier import TargetVerifier
171189

190+
# Validate whichever configs are present. The DraftModel
191+
# itself needs a flat-style ``SpeculativeConfig`` so its
192+
# model-load path stays one code path (the tree config
193+
# doesn't carry draft_model_id / draft_streaming_config, so a
194+
# tree-only setup is rejected below: the flat config must
195+
# also be supplied to provide the drafter handle).
196+
if speculative_config is not None:
197+
speculative_config.validate()
198+
if tree_speculative_config is not None:
199+
tree_speculative_config.validate()
200+
201+
# The drafter is shared across flat + tree.
202+
if speculative_config is None:
203+
raise ValueError(
204+
"tree_speculative_config requires a flat "
205+
"SpeculativeConfig (drafter model handle); pass "
206+
"both."
207+
)
208+
209+
self._speculative_stats = SpeculativeStats()
172210
log.info(
173-
"Speculative decoding enabled: %s",
211+
"Speculative decoding enabled: flat=%s tree=%s",
174212
speculative_config.summary(),
213+
tree_speculative_config.summary()
214+
if tree_speculative_config else "off",
175215
)
176-
self._speculative_stats = SpeculativeStats()
177216
draft = DraftModel(
178217
speculative_config, target_tokenizer=tokenizer
179218
)
180219
verifier = TargetVerifier(self.model, device=device)
220+
221+
# Flat engine: existing path, unchanged.
222+
from .speculative.engine import SpeculativeEngine
181223
self.speculative_engine = SpeculativeEngine(
182224
cfg=speculative_config,
183225
target_verifier=verifier,
184226
draft_model=draft,
185227
stats=self._speculative_stats,
186228
)
229+
230+
# Tree engine: only when its config is provided.
231+
if tree_speculative_config is not None:
232+
from .speculative.tree.engine import TreeSpeculativeEngine
233+
234+
target_step_ms = (
235+
cost_coefficients.step_latency_ms
236+
if cost_coefficients is not None else 50.0
237+
)
238+
# Draft step latency is unknown without probing the
239+
# drafter directly; approximate as a small fraction of
240+
# the target step (drafter is ~1/10th model size).
241+
draft_step_ms = max(1.0, target_step_ms * 0.15)
242+
243+
self.tree_speculative_engine = TreeSpeculativeEngine(
244+
cfg=tree_speculative_config,
245+
target_verifier=verifier,
246+
draft_model=draft,
247+
cost_coefficients=cost_coefficients,
248+
target_step_ms=target_step_ms,
249+
draft_step_ms=draft_step_ms,
250+
mode=speculative_config.mode,
251+
temperature=speculative_config.temperature,
252+
stats=self._speculative_stats,
253+
)
254+
255+
# Build the auto-selector. Shares the tree engine's
256+
# EWMA so its scoring reads the same observations the
257+
# tree engine writes after every round.
258+
from .speculative.mode_selector import ModeSelector
259+
self.mode_selector = ModeSelector(
260+
target_step_ms=target_step_ms,
261+
draft_step_ms=draft_step_ms,
262+
flat_available=True,
263+
tree_available=True,
264+
tree_config=tree_speculative_config,
265+
flat_k=speculative_config.draft_k,
266+
flat_cold_accept=0.4,
267+
tree_ewma=self.tree_speculative_engine.ewma,
268+
cost_coefficients=cost_coefficients,
269+
)
187270
else:
188271
self._speculative_stats = None
189272

@@ -270,6 +353,35 @@ def reset_cache(self) -> None:
270353
"""
271354
self.cache_manager.clear()
272355

356+
def set_cost_coefficients(self, cc: Any) -> None:
    """Install cost coefficients after the engine has been constructed.

    The coefficients object is stored on the engine and forwarded to the
    tree speculative engine and the mode selector when they exist. If
    ``cc`` exposes a usable ``step_latency_ms``, that value also replaces
    the ``target_step_ms`` of both consumers.

    The latency update is best-effort: a missing attribute, a ``None``
    ``cc``, a non-numeric value, or a non-finite / non-positive number
    leaves the previously configured ``target_step_ms`` untouched. The
    method can be called repeatedly; each call overwrites the previous
    coefficients.

    Args:
        cc: Coefficients object (or ``None``). Only its optional
            ``step_latency_ms`` attribute is read here.
    """
    self.cost_coefficients = cc

    # Parse the measured latency exactly once so both consumers receive
    # the same value. Only the failures a missing or malformed
    # coefficient can produce are tolerated; anything else propagates.
    try:
        step_ms = float(cc.step_latency_ms)
    except (AttributeError, TypeError, ValueError):
        step_ms = None

    # A NaN, infinite, zero or negative latency is unusable; keep the
    # existing value instead. The chained comparison is False for NaN.
    if step_ms is not None and not (0.0 < step_ms < float("inf")):
        step_ms = None

    # NOTE(review): only ``target_step_ms`` is refreshed here. If the
    # consumers also hold a draft-step estimate derived from the target
    # latency at construction time, it is not recalibrated — confirm
    # whether it should be.
    for consumer in (self.tree_speculative_engine, self.mode_selector):
        if consumer is None:
            continue
        consumer.cc = cc
        if step_ms is not None:
            consumer.target_step_ms = step_ms
384+
273385
def generate(
274386
self,
275387
prompt: str,
@@ -877,7 +989,11 @@ def _decode_with_kv(
877989
# We extend past_kv by one forward to cover that first sampled
878990
# token, then hand off — speculative's invariant is that past_kv
879991
# exactly covers the input prompt_ids.
880-
if self.speculative_engine is not None and len(generated) < max_new_tokens:
992+
any_spec = (
993+
self.speculative_engine is not None
994+
or self.tree_speculative_engine is not None
995+
)
996+
if any_spec and len(generated) < max_new_tokens:
881997
extended_pos = cached_len + len(live_ids)
882998
first_t = torch.tensor(
883999
[[generated[-1]]], dtype=torch.long, device=self.device
@@ -896,11 +1012,16 @@ def _decode_with_kv(
8961012
extended_prompt_ids = list(full_token_ids) + [generated[-1]]
8971013
from .speculative.bridge import run_speculative_decode
8981014

1015+
tree_cfg = self.tree_speculative_config
1016+
policy = tree_cfg.policy if tree_cfg is not None else "auto"
8991017
spec_generated, past_kv = run_speculative_decode(
9001018
full_token_ids=extended_prompt_ids,
9011019
target_past_kv=past_kv,
9021020
cached_length=len(extended_prompt_ids),
9031021
spec_engine=self.speculative_engine,
1022+
tree_engine=self.tree_speculative_engine,
1023+
mode_selector=self.mode_selector,
1024+
policy=policy,
9041025
max_new_tokens=max_new_tokens - len(generated),
9051026
eos_token_id=self.tokenizer.eos_token_id,
9061027
on_token=on_token,

src/kvboost/server/__main__.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,34 @@ def parse_args():
157157
help="Temperature applied to target logits in sampling mode "
158158
"(default: 1.0). Ignored in greedy mode.")
159159

160+
# SpecBlock-inspired tree speculative decoding. Requires the flat
161+
# speculative drafter to be set (uses the same draft model with a
162+
# tree-drafting wrapper). The ``ModeSelector`` then picks per-request
163+
# between flat-K and tree-(B,D) by expected wall-time tokens/s.
164+
p.add_argument("--speculative-tree", action="store_true", default=False,
165+
help="Enable SpecBlock-inspired tree speculative "
166+
"decoding alongside flat. Requires --speculative-"
167+
"draft-model. Per-request mode is auto-selected "
168+
"by the cost model unless --speculative-mode-policy "
169+
"overrides.")
170+
p.add_argument("--speculative-mode-policy", default=None,
171+
choices=["auto", "flat", "tree", "none"],
172+
help="Force one speculative mode per request. Default "
173+
"is 'auto' when --speculative-tree is set, else "
174+
"'flat'. 'none' disables speculation entirely.")
175+
p.add_argument("--speculative-tree-max-branching", type=int, default=4,
176+
help="Cap on per-node children in the draft tree "
177+
"(default: 4). Higher = wider tree.")
178+
p.add_argument("--speculative-tree-max-depth", type=int, default=6,
179+
help="Cap on tree depth (default: 6). Deeper trees "
180+
"win more when acceptance is high.")
181+
p.add_argument("--speculative-tree-node-budget", type=int, default=32,
182+
help="Total node-count cap for the tree (default: 32). "
183+
"Hard-bounds the target verifier's cost.")
184+
p.add_argument("--speculative-tree-cold-accept", type=float, default=0.5,
185+
help="Seed acceptance prior for the tree EWMA (default: "
186+
"0.5). Used until 16+ samples per (B,D) cohort.")
187+
160188
# Server
161189
p.add_argument("--host", default="0.0.0.0")
162190
p.add_argument("--port", type=int, default=8000)
@@ -363,6 +391,7 @@ def load_engine(args):
363391
args.streaming_quant_kernel,
364392
)
365393
speculative_cfg = _build_speculative_config(args)
394+
tree_speculative_cfg = _build_tree_speculative_config(args)
366395
engine = InferenceEngine.from_pretrained(
367396
args.model,
368397
streaming_config=streaming_config,
@@ -375,6 +404,7 @@ def load_engine(args):
375404
prefill_chunk_size=args.prefill_chunk_size,
376405
device=device,
377406
speculative_config=speculative_cfg,
407+
tree_speculative_config=tree_speculative_cfg,
378408
)
379409
log.info("Model loaded.")
380410
return engine
@@ -421,6 +451,7 @@ def load_engine(args):
421451
prefill_chunk_size=args.prefill_chunk_size,
422452
device=device,
423453
speculative_config=_build_speculative_config(args),
454+
tree_speculative_config=_build_tree_speculative_config(args),
424455
)
425456

426457
log.info("Model loaded.")
@@ -441,6 +472,34 @@ def _build_speculative_config(args):
441472
)
442473

443474

475+
def _build_tree_speculative_config(args):
476+
"""Build a TreeSpeculativeConfig from parsed CLI args, or return None
477+
when tree mode is disabled.
478+
479+
Requires the flat drafter (we reuse the same draft model wrapped
480+
by ``TreeDraftModel``). When ``--speculative-tree`` is set but no
481+
drafter is configured, raise a SystemExit with a clear message —
482+
silently disabling tree mode would mask a misconfiguration.
483+
"""
484+
if not getattr(args, "speculative_tree", False):
485+
return None
486+
if not getattr(args, "speculative_draft_model", None):
487+
raise SystemExit(
488+
"ERROR: --speculative-tree requires --speculative-draft-model "
489+
"(the tree drafter wraps the same small model). Pass both, "
490+
"or drop --speculative-tree."
491+
)
492+
from ..speculative import TreeSpeculativeConfig
493+
policy = getattr(args, "speculative_mode_policy", None) or "auto"
494+
return TreeSpeculativeConfig(
495+
max_branching=args.speculative_tree_max_branching,
496+
max_depth=args.speculative_tree_max_depth,
497+
node_budget=args.speculative_tree_node_budget,
498+
cold_accept=args.speculative_tree_cold_accept,
499+
policy=policy,
500+
)
501+
502+
444503
def main():
445504
args = parse_args()
446505

@@ -502,6 +561,10 @@ def main():
502561
"OOM planning enabled: auto_truncate=%s, safety_margin=%.0f%%",
503562
args.auto_truncate, args.planner_safety_margin * 100,
504563
)
564+
# Same coefficients drive tree-shape selection. The engine
565+
# already constructed its tree engine with defaults; this
566+
# writes the calibrated values in.
567+
engine.set_cost_coefficients(cost_coefficients)
505568

506569
worker = EngineWorker(
507570
engine=engine,

src/kvboost/speculative/__init__.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,21 @@
2727
from .config import SpeculativeConfig, SpeculativeMode
2828
from .draft import DraftModel
2929
from .engine import SpeculativeEngine
30-
from .rollback import truncate_past_kv
30+
from .mode_selector import ChosenMode, ModeSelector
31+
from .rollback import gather_kv_columns, truncate_past_kv
3132
from .sampler import verify_greedy, verify_sampling
3233
from .stats import SpeculativeStats
34+
from .tree import (
35+
AcceptanceEWMA,
36+
TreeShape,
37+
TreeSpeculativeConfig,
38+
pick_shape,
39+
)
40+
from .tree.engine import TreeSpeculativeEngine
3341
from .verifier import TargetVerifier
3442

3543
__all__ = [
44+
# flat
3645
"SpeculativeConfig",
3746
"SpeculativeMode",
3847
"SpeculativeEngine",
@@ -42,5 +51,15 @@
4251
"verify_greedy",
4352
"verify_sampling",
4453
"truncate_past_kv",
54+
"gather_kv_columns",
4555
"run_speculative_decode",
56+
# tree
57+
"TreeSpeculativeConfig",
58+
"TreeSpeculativeEngine",
59+
"TreeShape",
60+
"AcceptanceEWMA",
61+
"pick_shape",
62+
# mode selection
63+
"ModeSelector",
64+
"ChosenMode",
4665
]

0 commit comments

Comments
 (0)