Skip to content

Commit 17db854

Browse files
authored
refactor(spec): split EAGLEWorker into BaseSpecWorker/BaseDraftWorker… (#1080)
* refactor(spec): split EAGLEWorker into BaseSpecWorker/BaseDraftWorker + EagleDraftWorker (P1-2)
1 parent ffa86d4 commit 17db854

3 files changed

Lines changed: 766 additions & 643 deletions

File tree

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import TYPE_CHECKING
5+
6+
import jax
7+
from jax.sharding import NamedSharding
8+
from jax.sharding import PartitionSpec as P
9+
10+
if TYPE_CHECKING:
11+
from sgl_jax.srt.managers.tp_worker import ModelWorker
12+
13+
14+
def replicate_to_mesh(
    mesh: jax.sharding.Mesh, *arrs: jax.Array
) -> tuple[jax.Array, ...] | jax.Array:
    """Place every array in ``arrs`` on ``mesh`` with an empty partition spec.

    JIT outputs are typically vocab/data-sharded, while the host-side
    spec-decode orchestration (top_k, gather, build_tree) needs replicated
    arrays; this helper performs that re-placement in one ``device_put``.

    Args:
        mesh: Device mesh the arrays are placed on.
        *arrs: Arrays to replicate.

    Returns:
        The single replicated array when exactly one array was passed,
        otherwise a tuple of replicated arrays in argument order.
    """
    # An empty PartitionSpec shards along no mesh axis.
    replicated_sharding = NamedSharding(mesh, P())
    placed = jax.device_put(arrs, replicated_sharding)
    # Unwrap the one-element case so callers do not have to index.
    if len(placed) == 1:
        return placed[0]
    return placed
24+
25+
26+
class BaseDraftWorker(ABC):
    """Abstract interface for the draft-model side of speculative decoding.

    A concrete implementation holds the draft model runner and owns all
    draft-specific logic (multi-step decode, tree building, extend).
    Standard EAGLE uses ``EagleDraftWorker``; MTP will use
    ``MultiLayerDraftWorker``.
    """

    @abstractmethod
    def draft(self):
        """Run the draft step; every concrete draft worker must override this."""
38+
39+
40+
class BaseSpecWorker(ABC):
    """Abstract orchestrator for speculative decoding.

    Exposes a ``target_worker`` (the full model) and a ``draft_worker``
    (the draft/MTP model). Concrete subclasses implement the main entry
    point and connect their specific draft/verify logic.
    """

    @property
    @abstractmethod
    def target_worker(self) -> ModelWorker:
        """Worker for the full (target) model; subclasses must provide it."""

    @property
    @abstractmethod
    def draft_worker(self) -> BaseDraftWorker:
        """Worker for the draft/MTP model; subclasses must provide it."""

0 commit comments

Comments
 (0)