Skip to content

Commit f4d39f2

Browse files
authored
feat(spec): SpecInput data contract + host/device boundary annotations (#1064)
Part of #1053 (P1-5a). Defines the spec-decode data contract that #1053 P1-2 (BaseSpecWorker/BaseDraftWorker class refactor) will build on, so the abstraction layer has a fixed interface to target.

- spec_info.SpecInput: runtime_checkable Protocol exposing the three token counts the RFC requires to be separated (logical / allocated / verify) plus is_draft_input/is_verify_input/get_spec_adjust_token_coefficient/filter_batch/merge_batch. Docstring fixes the host/device boundary and DP-padded layout (Route 1) semantics.
- EagleDraftInput: implement SpecInput; per-field docstrings annotate device vs host, shape, JIT-cache-key participation, and the allocate_lens vs accept_length vs new_seq_lens distinction.
- EagleVerifyInput: implement SpecInput; per-field docstrings annotate device vs host and static-metadata fields; filter/merge raise (verify input is single-round).
- test_spec_info.py: Protocol conformance + three-token-count distinction.

No behavior change to the existing dp=1 path (pure additions: docstrings, Protocol, new methods). Type annotations on device fields changed np.ndarray → jax.Array | None to reflect actual runtime types after #1063. Depends on #1063 (P1-0) — diff is on top of that branch.
1 parent 5247c34 commit f4d39f2

4 files changed

Lines changed: 208 additions & 31 deletions

File tree

python/sgl_jax/srt/speculative/eagle_util.py

Lines changed: 106 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -379,42 +379,75 @@ def build_tree_kernel_efficient(
379379
@register_pytree_node_class
380380
@dataclass
381381
class EagleDraftInput:
382-
# Constant: alloc length per decode step
382+
"""Next-round draft state — the only persistent cross-round spec state.
383+
384+
Implements ``SpecInput``. MUST NOT hold worker/runner/pool/future handles.
385+
Under DP (Route 1), per-request fields use DP-padded order.
386+
"""
387+
383388
ALLOC_LEN_PER_DECODE: ClassVar[int] = None
384389

385-
# The inputs for decode
386-
# shape: (b, topk)
387-
topk_p: np.ndarray = None
388-
topk_index: np.ndarray = None
389-
# shape: (b, hidden_size)
390-
hidden_states: np.ndarray = None
390+
# --- Cross-round draft state (device arrays, consumed by next draft) ---
391+
#: device ``(b, topk)`` — top-k probs from previous draft/draft_extend.
392+
topk_p: jax.Array | None = None
393+
#: device ``(b, topk)`` — top-k token ids.
394+
topk_index: jax.Array | None = None
395+
#: device ``(b, hidden_size)`` — minimal hidden state for next draft step.
396+
#: Multi-layer MTP keeps per-step hidden locally inside one
397+
#: ``MultiLayerDraftWorker.draft()``; only this cross-round slice persists.
398+
hidden_states: jax.Array | None = None
399+
#: static metadata (pytree aux); changing it triggers a new compile shape.
391400
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
392401

393-
# Inputs for extend
394-
# shape: (b,)
395-
verified_id: np.ndarray = None
396-
accept_length: np.ndarray = None
402+
# --- Draft-extend inputs (device unless ``_cpu`` suffixed) ---
403+
#: device ``(b,)`` — verified token starting the next draft.
404+
verified_id: jax.Array | None = None
405+
#: device ``(b,)`` — accepted length used to select hidden in draft-extend.
406+
accept_length: jax.Array | None = None
407+
#: host ``(b,)`` int32 mirror of ``accept_length`` for scheduler bookkeeping.
397408
accept_length_cpu: np.ndarray | None = None
398409

399-
# Inputs for the attention backends
400-
# shape: (b + 1,)
401-
kv_indptr: np.ndarray = None
402-
kv_indices: np.ndarray = None
410+
# --- Attention-backend metadata (host, participates in metadata build) ---
411+
kv_indptr: np.ndarray | None = None
412+
kv_indices: np.ndarray | None = None
403413

404-
# Shape info for padding
414+
# --- Padding shape (static; participates in JIT cache key) ---
405415
num_tokens_per_batch: int = -1
406416
num_tokens_for_logprob_per_batch: int = -1
407417

408-
# Inputs for draft extend
409-
# shape: (b,)
410-
seq_lens_for_draft_extend: np.ndarray = None
411-
req_pool_indices_for_draft_extend: np.ndarray = None
418+
# --- Draft-extend bookkeeping (host) ---
419+
seq_lens_for_draft_extend: np.ndarray | None = None
420+
req_pool_indices_for_draft_extend: np.ndarray | None = None
412421

413-
# Inputs for V2 overlap worker
414-
# future_indices: Optional[FutureIndices] = None
422+
# --- KV lifetime (host, scheduler-visible) ---
423+
#: host ``(b,)`` — KV length already allocated in ``req_to_token_pool`` for
424+
#: next-round pre-allocation and over-allocated slot release. Distinct from
425+
#: ``accept_length`` (logical) and ``new_seq_lens`` (scheduler-visible).
415426
allocate_lens: np.ndarray | None = None
427+
#: host ``(b,)`` — scheduler-visible logical length after verify. May be
428+
#: derived from ``old_seq_lens + accept_length`` if not stored.
416429
new_seq_lens: np.ndarray | None = None
417-
# verify_done: Optional[torch.cuda.Event] = None
430+
431+
# ---- SpecInput protocol -------------------------------------------------
432+
def is_draft_input(self) -> bool:
    """Report that this object is draft-side speculative input.

    Returns:
        Always ``True``.
    """
    return True
434+
435+
def is_verify_input(self) -> bool:
    """Report that this object is not verify-side speculative input.

    Returns:
        Always ``False``.
    """
    return False
437+
438+
def get_spec_adjust_token_coefficient(self) -> int:
    """Return the scheduler token-budget multiplier for draft input.

    Returns:
        The class-level ``ALLOC_LEN_PER_DECODE`` when it is set to a truthy
        value, otherwise ``1`` (covers both ``None`` and ``0``).
    """
    alloc_len = EagleDraftInput.ALLOC_LEN_PER_DECODE
    # Falsy (unset or zero) falls back to a neutral multiplier of 1.
    return alloc_len if alloc_len else 1
440+
441+
def get_logical_token_num(self, bs: int) -> np.ndarray:
    """Return the per-request logical (accepted) token count on host.

    Args:
        bs: Batch size; only used to size the fallback when no accept
            lengths have been recorded. It is not checked against the
            length of ``accept_length_cpu``.

    Returns:
        An int32 ``np.ndarray``. When ``accept_length_cpu`` is set, its
        values are returned (the same object if it is already an int32
        ndarray, otherwise a converted copy); when it is ``None``, an
        array of ``bs`` ones.
    """
    if self.accept_length_cpu is None:
        return np.ones(bs, dtype=np.int32)
    # Normalise to the int32 ndarray the SpecInput contract promises;
    # np.asarray does not copy when the input already matches.
    return np.asarray(self.accept_length_cpu, dtype=np.int32)
445+
446+
def get_allocated_token_num(self) -> np.ndarray | None:
447+
return self.allocate_lens
448+
449+
def get_verify_token_num(self, bs: int) -> int:
    """Return the number of tokens this input sends to target verify.

    Args:
        bs: Batch size; unused for draft input.

    Returns:
        Always ``0``.
    """
    return 0
418451

419452
def tree_flatten(self):
420453
accept_length_cpu_arr = (
@@ -662,11 +695,10 @@ def merge_batch(self, spec_info: EagleDraftInput):
662695
return
663696
if spec_info.hidden_states is None:
664697
return
665-
# FIXME(pc) this operate should be put on cpu
666-
self.hidden_states = np.concatenate([self.hidden_states, spec_info.hidden_states], axis=0)
667-
self.verified_id = np.concatenate([self.verified_id, spec_info.verified_id], axis=0)
668-
self.topk_p = np.concatenate([self.topk_p, spec_info.topk_p])
669-
self.topk_index = np.concatenate([self.topk_index, spec_info.topk_index])
698+
self.hidden_states = jnp.concatenate([self.hidden_states, spec_info.hidden_states], axis=0)
699+
self.verified_id = jnp.concatenate([self.verified_id, spec_info.verified_id], axis=0)
700+
self.topk_p = jnp.concatenate([self.topk_p, spec_info.topk_p])
701+
self.topk_index = jnp.concatenate([self.topk_index, spec_info.topk_index])
670702
self.allocate_lens = np.concatenate([self.allocate_lens, spec_info.allocate_lens])
671703

672704

@@ -687,22 +719,65 @@ class EagleVerifyOutput:
687719
@register_pytree_node_class
688720
@dataclass
689721
class EagleVerifyInput:
690-
# container type for pytree
722+
"""Target-verify input. Implements ``SpecInput``.
723+
724+
Fully describes token/position/mask/tree-index for verify so
725+
``BaseSpecWorker.verify()`` never reads draft-worker internal state.
726+
Under DP (Route 1), per-request fields use DP-padded order; verify
727+
metadata must reshape to per-DP view before generating cu_q/kv_lens.
728+
"""
729+
730+
# --- Device arrays (enter target verify forward / sampling) ---
731+
#: device ``(b*draft_token_num,)`` — flattened draft tokens to verify.
691732
draft_token: jax.Array
733+
#: device ``(sum(q_i*kv_i),)`` — tree attention mask; shape participates
734+
#: in the JIT cache key.
692735
custom_mask: jax.Array
736+
#: device ``(b*draft_token_num,)`` — verify positions (follows
737+
#: ``ForwardBatch`` host/device convention).
693738
positions: jax.Array
739+
#: device — tree verify index (sampling-kernel convention).
694740
retrive_index: jax.Array
741+
#: device — tree child pointer for tree sampling.
695742
retrive_next_token: jax.Array
743+
#: device — tree sibling pointer for tree sampling.
696744
retrive_next_sibling: jax.Array
697745
retrive_cum_len: jax.Array
746+
#: host ``(b,)`` — for verify attention metadata + DP token accounting.
698747
seq_lens_cpu: np.ndarray
699-
# common type for pytree
748+
749+
# --- Static metadata (pytree aux; changes trigger new compile shape) ---
700750
spec_steps: int
701751
topk: int
752+
#: per-request verify token count (constant within a precompile shape).
702753
draft_token_num: int
703754
seq_lens_sum: int
704755
capture_hidden_mode: CaptureHiddenMode
705-
# grammar: BaseGrammarObject = None
756+
757+
# ---- SpecInput protocol -------------------------------------------------
758+
def is_draft_input(self) -> bool:
    """Report that this object is not draft-side speculative input.

    Returns:
        Always ``False``.
    """
    return False
760+
761+
def is_verify_input(self) -> bool:
    """Report that this object is verify-side speculative input.

    Returns:
        Always ``True``.
    """
    return True
763+
764+
def get_spec_adjust_token_coefficient(self) -> int:
    """Return the scheduler token-budget multiplier for verify input.

    Returns:
        The instance's ``draft_token_num``.
    """
    coefficient = self.draft_token_num
    return coefficient
766+
767+
def get_logical_token_num(self, bs: int) -> np.ndarray:
    """Return a per-request logical token count of one for every request.

    Args:
        bs: Batch size, used as the length of the returned array.

    Returns:
        An int32 ``np.ndarray`` of shape ``(bs,)`` filled with ones.
    """
    return np.full(bs, 1, dtype=np.int32)
769+
770+
def get_allocated_token_num(self) -> np.ndarray | None:
771+
return None
772+
773+
def get_verify_token_num(self, bs: int) -> int:
    """Return the flattened number of tokens forwarded to target verify.

    Args:
        bs: Batch size (number of requests).

    Returns:
        ``bs`` multiplied by the instance's ``draft_token_num``.
    """
    tokens_per_request = self.draft_token_num
    return bs * tokens_per_request
775+
776+
def filter_batch(self, new_indices: np.ndarray, has_been_filtered: bool = True) -> None:
    """Reject batch filtering for verify input.

    Args:
        new_indices: Ignored.
        has_been_filtered: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    reason = "EagleVerifyInput is consumed within one round"
    raise NotImplementedError(reason)
778+
779+
def merge_batch(self, other) -> None:
    """Reject batch merging for verify input.

    Args:
        other: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    reason = "EagleVerifyInput is consumed within one round"
    raise NotImplementedError(reason)
706781

707782
def tree_flatten(self):
708783
seq_lens_sum_arr = _as_int32_array(self.seq_lens_sum, fallback=0)

python/sgl_jax/srt/speculative/spec_info.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,60 @@
1+
from __future__ import annotations
2+
13
import logging
24
from enum import IntEnum, auto
5+
from typing import Protocol, runtime_checkable
36

47
import jax
8+
import numpy as np
59

610
from sgl_jax.srt.layers.logits_processor import LogitsProcessorOutput
711

812
logger = logging.getLogger(__name__)
913

1014

15+
@runtime_checkable
class SpecInput(Protocol):
    """Structural contract for speculative-decode state carried in
    ``ModelWorkerBatch.spec_info`` (#1053 P1-5a data contract).

    Under spec decode the scheduler, the KV allocator and the verify path
    each need a different token count, so the protocol exposes the three
    separately:

    - **logical** — tokens the scheduler advances request output by
      (= accepted count incl. bonus). Host scalar/array.
    - **allocated** — KV slots already pre-allocated this round (used to
      trim over-allocation on finished reqs). Host array.
    - **verify** — flattened token count target verify will forward (drives
      verify attention metadata + DP token accounting). Host scalar.

    Implementations MUST NOT hold worker/runner/pool/future/callback handles
    in pytree children (these would enter the JIT cache key). Device arrays
    (``topk_p``, ``hidden_states``, ``draft_token``, ...) stay on device;
    lengths/indices stay host-side ``np.ndarray``.

    DP layout (Route 1, target+draft both DP): all per-request fields use
    DP-padded order — section ``[dp_rank*per_dp_bs : dp_rank*per_dp_bs+real_bs]``.
    Padding slots MUST NOT participate in valid state updates.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only
    verifies that the methods below exist, not their signatures.
    """

    def is_draft_input(self) -> bool:
        """Whether this object is draft-side input."""

    def is_verify_input(self) -> bool:
        """Whether this object is verify-side input."""

    def get_spec_adjust_token_coefficient(self) -> int:
        """Multiplier for scheduler new-token budgeting (e.g. draft_token_num)."""

    def get_logical_token_num(self, bs: int) -> np.ndarray:
        """Per-request host int32 ``(bs,)``; callers sum for batch totals."""

    def get_allocated_token_num(self) -> np.ndarray | None:
        """Host array of the *allocated* count, or ``None`` if not tracked."""

    def get_verify_token_num(self, bs: int) -> int:
        """Host scalar *verify* count for a batch of ``bs`` requests."""

    def filter_batch(self, new_indices: np.ndarray, has_been_filtered: bool = True) -> None:
        """Batch-filter hook; semantics are defined by the implementation."""

    def merge_batch(self, other: SpecInput) -> None:
        """Batch-merge hook; semantics are defined by the implementation."""
56+
57+
1158
class SpeculativeAlgorithm(IntEnum):
1259
NONE = auto()
1360
EAGLE = auto()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""SpecInput protocol conformance for EagleDraftInput / EagleVerifyInput."""
2+
3+
import numpy as np
4+
5+
from sgl_jax.srt.model_executor.forward_batch_info import CaptureHiddenMode
6+
from sgl_jax.srt.speculative.eagle_util import EagleDraftInput, EagleVerifyInput
7+
from sgl_jax.srt.speculative.spec_info import SpecInput
8+
9+
10+
def test_eagle_draft_input_is_spec_input():
    """EagleDraftInput conforms to SpecInput and reports draft-side counts."""
    accepted = np.array([2, 3, 1], dtype=np.int32)
    allocated = np.array([10, 12, 8], dtype=np.int32)
    draft_input = EagleDraftInput(accept_length_cpu=accepted, allocate_lens=allocated)

    # Structural conformance and role flags.
    assert isinstance(draft_input, SpecInput)
    assert draft_input.is_draft_input()
    assert not draft_input.is_verify_input()

    # Token counts: logical mirrors accept lengths, verify is zero for draft.
    logical = draft_input.get_logical_token_num(bs=3)
    assert (logical == np.array([2, 3, 1])).all()
    assert draft_input.get_verify_token_num(bs=3) == 0
    reported_alloc = draft_input.get_allocated_token_num()
    assert (reported_alloc == np.array([10, 12, 8])).all()
    assert draft_input.get_spec_adjust_token_coefficient() >= 1
21+
22+
23+
def test_eagle_verify_input_is_spec_input():
    """EagleVerifyInput conforms to SpecInput and reports verify-side counts."""

    def _zeros(length):
        # Fresh int32 zero array per field so no two fields share a buffer.
        return np.zeros(length, dtype=np.int32)

    verify_input = EagleVerifyInput(
        draft_token=_zeros(8),
        custom_mask=_zeros(1),
        positions=_zeros(8),
        retrive_index=_zeros(8),
        retrive_next_token=_zeros(8),
        retrive_next_sibling=_zeros(8),
        retrive_cum_len=_zeros(3),
        seq_lens_cpu=np.array([5, 7], dtype=np.int32),
        spec_steps=3,
        topk=1,
        draft_token_num=4,
        seq_lens_sum=12,
        capture_hidden_mode=CaptureHiddenMode.FULL,
    )

    # Structural conformance and role flags.
    assert isinstance(verify_input, SpecInput)
    assert verify_input.is_verify_input()
    assert not verify_input.is_draft_input()

    # 2 requests x 4 draft tokens each = 8 verify tokens.
    assert verify_input.get_verify_token_num(bs=2) == 8
    assert verify_input.get_spec_adjust_token_coefficient() == 4
    assert verify_input.get_allocated_token_num() is None
44+
45+
46+
def test_three_token_counts_are_distinct():
    """RFC #1053: logical / allocated / verify must be exposed independently."""
    accepted = np.array([2, 2], dtype=np.int32)
    allocated = np.array([100, 100], dtype=np.int32)
    draft_input = EagleDraftInput(accept_length_cpu=accepted, allocate_lens=allocated)

    logical_total = int(draft_input.get_logical_token_num(bs=2).sum())
    allocated_total = int(draft_input.get_allocated_token_num().sum())
    verify_total = draft_input.get_verify_token_num(bs=2)

    # Each accessor must report its own quantity, not a shared value.
    assert logical_total == 4
    assert allocated_total == 200
    assert verify_total == 0

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ def run_one_file(
459459
TestFile("python/sgl_jax/test/mem_cache/test_hybrid_req_to_token_pool.py", 1),
460460
TestFile("python/sgl_jax/test/speculative/test_eagle_tree_build.py", 1),
461461
TestFile("python/sgl_jax/test/speculative/test_eagle_utils.py", 1),
462+
TestFile("python/sgl_jax/test/speculative/test_spec_info.py", 0.2, runner="pytest"),
462463
TestFile("python/sgl_jax/test/models/test_mimo_v2_nextn.py", 0.2, runner="pytest"),
463464
TestFile("python/sgl_jax/test/multimodal/test_wan_vae_precision.py", 1),
464465
TestFile("python/sgl_jax/test/multimodal/test_vae_scheduler.py", 2.5),

0 commit comments

Comments
 (0)