|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from abc import ABC, abstractmethod |
| 4 | +from typing import TYPE_CHECKING |
| 5 | + |
| 6 | +import jax |
| 7 | +from jax.sharding import NamedSharding |
| 8 | +from jax.sharding import PartitionSpec as P |
| 9 | + |
| 10 | +if TYPE_CHECKING: |
| 11 | + from sgl_jax.srt.managers.tp_worker import ModelWorker |
| 12 | + |
| 13 | + |
| 14 | +def replicate_to_mesh( |
| 15 | + mesh: jax.sharding.Mesh, *arrs: jax.Array |
| 16 | +) -> tuple[jax.Array, ...] | jax.Array: |
| 17 | + """Replicate arrays across a mesh under explicit sharding. |
| 18 | +
|
| 19 | + JIT outputs are typically vocab/data-sharded; spec-decode host orchestration |
| 20 | + (top_k, gather, build_tree) needs replicated arrays. |
| 21 | + """ |
| 22 | + out = jax.device_put(arrs, NamedSharding(mesh, P())) |
| 23 | + return out[0] if len(out) == 1 else out |
| 24 | + |
| 25 | + |
| 26 | +class BaseDraftWorker(ABC): |
| 27 | + """Draft model worker interface for speculative decoding. |
| 28 | +
|
| 29 | + Concrete implementations hold the draft model runner and own all |
| 30 | + draft-specific logic (multi-step decode, tree building, extend). |
| 31 | + Standard EAGLE uses ``EagleDraftWorker``; MTP will use |
| 32 | + ``MultiLayerDraftWorker``. |
| 33 | + """ |
| 34 | + |
| 35 | + @abstractmethod |
| 36 | + def draft(self): |
| 37 | + pass |
| 38 | + |
| 39 | + |
| 40 | +class BaseSpecWorker(ABC): |
| 41 | + """Speculative decode orchestrator. |
| 42 | +
|
| 43 | + Owns a ``target_worker`` (the full model) and a ``draft_worker`` |
| 44 | + (the draft/MTP model). Concrete subclasses implement the main |
| 45 | + entry point and connect their specific draft/verify logic. |
| 46 | + """ |
| 47 | + |
| 48 | + @property |
| 49 | + @abstractmethod |
| 50 | + def target_worker(self) -> ModelWorker: |
| 51 | + pass |
| 52 | + |
| 53 | + @property |
| 54 | + @abstractmethod |
| 55 | + def draft_worker(self) -> BaseDraftWorker: |
| 56 | + pass |
0 commit comments