@@ -379,42 +379,75 @@ def build_tree_kernel_efficient(
@register_pytree_node_class
@dataclass
class EagleDraftInput:
    """Next-round draft state — the only persistent cross-round spec state.

    Implements ``SpecInput``. MUST NOT hold worker/runner/pool/future handles.
    Under DP (Route 1), per-request fields use DP-padded order.
    """

    # Allocation length per decode step. Configured once at startup; stays
    # None until then, hence the Optional annotation (the original declared
    # ``ClassVar[int]`` with a None default, which is inconsistent).
    ALLOC_LEN_PER_DECODE: ClassVar[int | None] = None

    # --- Cross-round draft state (device arrays, consumed by next draft) ---
    #: device ``(b, topk)`` — top-k probs from previous draft/draft_extend.
    topk_p: jax.Array | None = None
    #: device ``(b, topk)`` — top-k token ids.
    topk_index: jax.Array | None = None
    #: device ``(b, hidden_size)`` — minimal hidden state for next draft step.
    #: Multi-layer MTP keeps per-step hidden locally inside one
    #: ``MultiLayerDraftWorker.draft()``; only this cross-round slice persists.
    hidden_states: jax.Array | None = None
    #: static metadata (pytree aux); changing it triggers a new compile shape.
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL

    # --- Draft-extend inputs (device unless ``_cpu`` suffixed) ---
    #: device ``(b,)`` — verified token starting the next draft.
    verified_id: jax.Array | None = None
    #: device ``(b,)`` — accepted length used to select hidden in draft-extend.
    accept_length: jax.Array | None = None
    #: host ``(b,)`` int32 mirror of ``accept_length`` for scheduler bookkeeping.
    accept_length_cpu: np.ndarray | None = None

    # --- Attention-backend metadata (host, participates in metadata build) ---
    kv_indptr: np.ndarray | None = None
    kv_indices: np.ndarray | None = None

    # --- Padding shape (static; participates in JIT cache key) ---
    num_tokens_per_batch: int = -1
    num_tokens_for_logprob_per_batch: int = -1

    # --- Draft-extend bookkeeping (host) ---
    seq_lens_for_draft_extend: np.ndarray | None = None
    req_pool_indices_for_draft_extend: np.ndarray | None = None

    # --- KV lifetime (host, scheduler-visible) ---
    #: host ``(b,)`` — KV length already allocated in ``req_to_token_pool`` for
    #: next-round pre-allocation and over-allocated slot release. Distinct from
    #: ``accept_length`` (logical) and ``new_seq_lens`` (scheduler-visible).
    allocate_lens: np.ndarray | None = None
    #: host ``(b,)`` — scheduler-visible logical length after verify. May be
    #: derived from ``old_seq_lens + accept_length`` if not stored.
    new_seq_lens: np.ndarray | None = None

    # ---- SpecInput protocol -------------------------------------------------
    def is_draft_input(self) -> bool:
        """Return True: this object feeds the draft stage."""
        return True

    def is_verify_input(self) -> bool:
        """Return False: this object is not a verify-stage input."""
        return False

    def get_spec_adjust_token_coefficient(self) -> int:
        """Per-request token coefficient for scheduler accounting.

        Falls back to 1 while ``ALLOC_LEN_PER_DECODE`` is unset (None).
        """
        return EagleDraftInput.ALLOC_LEN_PER_DECODE or 1

    def get_logical_token_num(self, bs: int) -> np.ndarray:
        """Host ``(b,)`` logical token count per request.

        Uses the host mirror of the accepted lengths when present; otherwise
        assumes one token per request.
        """
        if self.accept_length_cpu is not None:
            return self.accept_length_cpu
        return np.ones(bs, dtype=np.int32)

    def get_allocated_token_num(self) -> np.ndarray | None:
        """Host ``(b,)`` already-allocated KV lengths (None if untracked)."""
        return self.allocate_lens

    def get_verify_token_num(self, bs: int) -> int:
        """Draft input contributes no verify tokens."""
        return 0
418451
419452 def tree_flatten (self ):
420453 accept_length_cpu_arr = (
@@ -662,11 +695,10 @@ def merge_batch(self, spec_info: EagleDraftInput):
662695 return
663696 if spec_info .hidden_states is None :
664697 return
665- # FIXME(pc) this operate should be put on cpu
666- self .hidden_states = np .concatenate ([self .hidden_states , spec_info .hidden_states ], axis = 0 )
667- self .verified_id = np .concatenate ([self .verified_id , spec_info .verified_id ], axis = 0 )
668- self .topk_p = np .concatenate ([self .topk_p , spec_info .topk_p ])
669- self .topk_index = np .concatenate ([self .topk_index , spec_info .topk_index ])
698+ self .hidden_states = jnp .concatenate ([self .hidden_states , spec_info .hidden_states ], axis = 0 )
699+ self .verified_id = jnp .concatenate ([self .verified_id , spec_info .verified_id ], axis = 0 )
700+ self .topk_p = jnp .concatenate ([self .topk_p , spec_info .topk_p ])
701+ self .topk_index = jnp .concatenate ([self .topk_index , spec_info .topk_index ])
670702 self .allocate_lens = np .concatenate ([self .allocate_lens , spec_info .allocate_lens ])
671703
672704
@@ -687,22 +719,65 @@ class EagleVerifyOutput:
@register_pytree_node_class
@dataclass
class EagleVerifyInput:
    """Target-verify input. Implements ``SpecInput``.

    Fully describes token/position/mask/tree-index for verify so
    ``BaseSpecWorker.verify()`` never reads draft-worker internal state.
    Under DP (Route 1), per-request fields use DP-padded order; verify
    metadata must reshape to per-DP view before generating cu_q/kv_lens.
    """

    # --- Device arrays (enter target verify forward / sampling) ---
    #: device ``(b*draft_token_num,)`` — flattened draft tokens to verify.
    draft_token: jax.Array
    #: device ``(sum(q_i*kv_i),)`` — tree attention mask; shape participates
    #: in the JIT cache key.
    custom_mask: jax.Array
    #: device ``(b*draft_token_num,)`` — verify positions (follows
    #: ``ForwardBatch`` host/device convention).
    positions: jax.Array
    #: device — tree verify index (sampling-kernel convention).
    retrive_index: jax.Array
    #: device — tree child pointer for tree sampling.
    retrive_next_token: jax.Array
    #: device — tree sibling pointer for tree sampling.
    retrive_next_sibling: jax.Array
    retrive_cum_len: jax.Array
    #: host ``(b,)`` — for verify attention metadata + DP token accounting.
    seq_lens_cpu: np.ndarray

    # --- Static metadata (pytree aux; changes trigger new compile shape) ---
    spec_steps: int
    topk: int
    #: per-request verify token count (constant within a precompile shape).
    draft_token_num: int
    seq_lens_sum: int
    capture_hidden_mode: CaptureHiddenMode

    # ---- SpecInput protocol -------------------------------------------------
    def is_draft_input(self) -> bool:
        """Return False: this object is not a draft-stage input."""
        return False

    def is_verify_input(self) -> bool:
        """Return True: this object feeds target verify."""
        return True

    def get_spec_adjust_token_coefficient(self) -> int:
        """Per-request token coefficient: the verify token count."""
        return self.draft_token_num

    def get_logical_token_num(self, bs: int) -> np.ndarray:
        """Host ``(b,)`` logical token count: one per request for verify."""
        return np.ones(bs, dtype=np.int32)

    def get_allocated_token_num(self) -> np.ndarray | None:
        """Verify input does not track per-request KV allocation."""
        return None

    def get_verify_token_num(self, bs: int) -> int:
        """Total verify tokens across the batch."""
        return bs * self.draft_token_num

    def filter_batch(self, new_indices: np.ndarray, has_been_filtered: bool = True) -> None:
        # Verify input never outlives a round, so batch filtering is invalid.
        raise NotImplementedError("EagleVerifyInput is consumed within one round")

    def merge_batch(self, other) -> None:
        # Verify input never outlives a round, so batch merging is invalid.
        raise NotImplementedError("EagleVerifyInput is consumed within one round")
706781
707782 def tree_flatten (self ):
708783 seq_lens_sum_arr = _as_int32_array (self .seq_lens_sum , fallback = 0 )
0 commit comments