3030import time
3131from contextlib import contextmanager
3232from dataclasses import dataclass , field
33- from typing import Callable , Dict , List , Optional , Set , Tuple
33+ from typing import Any , Callable , Dict , List , Optional , Set , TYPE_CHECKING , Tuple
34+
35+ if TYPE_CHECKING :
36+ from .speculative .tree .config import TreeSpeculativeConfig
3437
3538import torch
3639from transformers import AutoModelForCausalLM , AutoTokenizer
@@ -101,6 +104,11 @@ def __init__(
101104 prefill_chunk_size : int = 0 ,
102105 # Speculative decoding (None = disabled, baseline decode path)
103106 speculative_config : Optional ["SpeculativeConfig" ] = None ,
107+ # Tree speculative — SpecBlock-inspired, may coexist with flat
108+ tree_speculative_config : Optional ["TreeSpeculativeConfig" ] = None ,
109+ # Cost coefficients (probed at server startup) for cost-aware
110+ # tree shape + mode selection. None = degraded mode (defaults).
111+ cost_coefficients : Any = None ,
104112 ):
105113 if device is None :
106114 device = default_device ()
@@ -160,30 +168,105 @@ def __init__(
160168
161169 # Speculative decoding (decode-phase orthogonal to recompute_strategy).
162170 # CacheBlend handles prefill; speculative handles decode. They stack.
171+ # Two flavors: flat (token-by-token K draft) and tree (SpecBlock-
172+ # inspired). Both may be present; the bridge / ModeSelector picks
173+ # per-request when so.
163174 self .speculative_config = speculative_config
175+ self .tree_speculative_config = tree_speculative_config
176+ self .cost_coefficients = cost_coefficients
164177 self .speculative_engine = None
165- if speculative_config is not None :
166- speculative_config .validate ()
178+ self .tree_speculative_engine = None
179+ self .mode_selector = None
180+
181+ need_draft = (
182+ speculative_config is not None
183+ or tree_speculative_config is not None
184+ )
185+ if need_draft :
167186 from .speculative .draft import DraftModel
168- from .speculative .engine import SpeculativeEngine
169187 from .speculative .stats import SpeculativeStats
170188 from .speculative .verifier import TargetVerifier
171189
190+ # Validate whichever configs are present. The DraftModel
191+ # itself needs a flat-style ``SpeculativeConfig`` so its
192+ # model-load path stays one code path (the tree config
193+ # doesn't carry draft_model_id / draft_streaming_config; if
194+ # only tree is wired we still rely on the flat config for
195+ # the drafter handle).
196+ if speculative_config is not None :
197+ speculative_config .validate ()
198+ if tree_speculative_config is not None :
199+ tree_speculative_config .validate ()
200+
201+ # The drafter is shared across flat + tree.
202+ if speculative_config is None :
203+ raise ValueError (
204+ "tree_speculative_config requires a flat "
205+ "SpeculativeConfig (drafter model handle); pass "
206+ "both."
207+ )
208+
209+ self ._speculative_stats = SpeculativeStats ()
172210 log .info (
173- "Speculative decoding enabled: %s" ,
211+ "Speculative decoding enabled: flat=%s tree= %s" ,
174212 speculative_config .summary (),
213+ tree_speculative_config .summary ()
214+ if tree_speculative_config else "off" ,
175215 )
176- self ._speculative_stats = SpeculativeStats ()
177216 draft = DraftModel (
178217 speculative_config , target_tokenizer = tokenizer
179218 )
180219 verifier = TargetVerifier (self .model , device = device )
220+
221+ # Flat engine: existing path, unchanged.
222+ from .speculative .engine import SpeculativeEngine
181223 self .speculative_engine = SpeculativeEngine (
182224 cfg = speculative_config ,
183225 target_verifier = verifier ,
184226 draft_model = draft ,
185227 stats = self ._speculative_stats ,
186228 )
229+
230+ # Tree engine: only when its config is provided.
231+ if tree_speculative_config is not None :
232+ from .speculative .tree .engine import TreeSpeculativeEngine
233+
234+ target_step_ms = (
235+ cost_coefficients .step_latency_ms
236+ if cost_coefficients is not None else 50.0
237+ )
238+ # Draft step latency is unknown without probing the
239+ # drafter directly; approximate as a small fraction of
240+ # the target step (drafter is ~1/10th model size).
241+ draft_step_ms = max (1.0 , target_step_ms * 0.15 )
242+
243+ self .tree_speculative_engine = TreeSpeculativeEngine (
244+ cfg = tree_speculative_config ,
245+ target_verifier = verifier ,
246+ draft_model = draft ,
247+ cost_coefficients = cost_coefficients ,
248+ target_step_ms = target_step_ms ,
249+ draft_step_ms = draft_step_ms ,
250+ mode = speculative_config .mode ,
251+ temperature = speculative_config .temperature ,
252+ stats = self ._speculative_stats ,
253+ )
254+
255+ # Build the auto-selector. Shares the tree engine's
256+ # EWMA so its scoring reads the same observations the
257+ # tree engine writes after every round.
258+ from .speculative .mode_selector import ModeSelector
259+ self .mode_selector = ModeSelector (
260+ target_step_ms = target_step_ms ,
261+ draft_step_ms = draft_step_ms ,
262+ flat_available = True ,
263+ tree_available = True ,
264+ tree_config = tree_speculative_config ,
265+ flat_k = speculative_config .draft_k ,
266+ flat_cold_accept = 0.4 ,
267+ tree_ewma = self .tree_speculative_engine .ewma ,
268+ cost_coefficients = cost_coefficients ,
269+ )
187270 else :
188271 self ._speculative_stats = None
189272
@@ -270,6 +353,35 @@ def reset_cache(self) -> None:
270353 """
271354 self .cache_manager .clear ()
272355
356+ def set_cost_coefficients (self , cc : Any ) -> None :
357+ """Populate cost coefficients post-construction.
358+
359+ The server probes coefficients AFTER engine load (the probe
360+ needs the loaded model), then plumbs them back here. They
361+ drive tree-shape selection and mode-auto-selection; setting
362+ them late just means the first request uses the defaults
363+ and subsequent requests are calibrated. Safe to call multiple
364+ times (e.g. if the operator updates them via /v1/stats).
365+ """
366+ self .cost_coefficients = cc
367+ if self .tree_speculative_engine is not None :
368+ self .tree_speculative_engine .cc = cc
369+ # Update measured step latency if available — the tree
370+ # engine multiplies this by predicted node count, so a
371+ # bad value distorts every shape decision.
372+ try :
373+ self .tree_speculative_engine .target_step_ms = float (
374+ cc .step_latency_ms
375+ )
376+ except Exception :
377+ pass
378+ if self .mode_selector is not None :
379+ self .mode_selector .cc = cc
380+ try :
381+ self .mode_selector .target_step_ms = float (cc .step_latency_ms )
382+ except Exception :
383+ pass
384+
273385 def generate (
274386 self ,
275387 prompt : str ,
@@ -877,7 +989,11 @@ def _decode_with_kv(
877989 # We extend past_kv by one forward to cover that first sampled
878990 # token, then hand off — speculative's invariant is that past_kv
879991 # exactly covers the input prompt_ids.
880- if self .speculative_engine is not None and len (generated ) < max_new_tokens :
992+ any_spec = (
993+ self .speculative_engine is not None
994+ or self .tree_speculative_engine is not None
995+ )
996+ if any_spec and len (generated ) < max_new_tokens :
881997 extended_pos = cached_len + len (live_ids )
882998 first_t = torch .tensor (
883999 [[generated [- 1 ]]], dtype = torch .long , device = self .device
@@ -896,11 +1012,16 @@ def _decode_with_kv(
8961012 extended_prompt_ids = list (full_token_ids ) + [generated [- 1 ]]
8971013 from .speculative .bridge import run_speculative_decode
8981014
1015+ tree_cfg = self .tree_speculative_config
1016+ policy = tree_cfg .policy if tree_cfg is not None else "auto"
8991017 spec_generated , past_kv = run_speculative_decode (
9001018 full_token_ids = extended_prompt_ids ,
9011019 target_past_kv = past_kv ,
9021020 cached_length = len (extended_prompt_ids ),
9031021 spec_engine = self .speculative_engine ,
1022+ tree_engine = self .tree_speculative_engine ,
1023+ mode_selector = self .mode_selector ,
1024+ policy = policy ,
9041025 max_new_tokens = max_new_tokens - len (generated ),
9051026 eos_token_id = self .tokenizer .eos_token_id ,
9061027 on_token = on_token ,
0 commit comments