Description
Checklist
- If this is not a feature request but a general question, please start a discussion at https://github.com/sgl-project/sglang/discussions. Otherwise, it will be closed.
- Please use English. Otherwise, it will be closed.
Motivation
SGLang supports CUDA, ROCm, Intel XPU, Ascend NPU, Habana HPU, and CPU. As more
hardware vendors seek integration with SGLang (Moore Threads, …), the current
architecture has the following problems:
- Hardware logic is scattered across many files, and there is no stable
integration contract. Take NPU as an example: the only way in is to read the
source, find every `if _is_npu:` guard, and add a new branch alongside it.
- Maintaining a backend without dedicated CI is unsustainable. Backends that
lack full CI coverage and continuous contribution slowly rot.
- New features stall. Every generic optimization (chunked prefill, speculative
decoding, disaggregation) must be propagated to every backend separately or is
silently broken.
Proposed Change
SGL-Diffusion already solved this problem with a Platform class and a
current_platform singleton. SRT has not. This RFC proposes:
- A unified `Platform` base class shared by both runtimes.
- A plugin mechanism (Python entry points) so vendors can ship support as a
standalone `pip install`-able package.
The directory layout after migration from sgl-diffusion to srt:
```
python/sglang/
├── platforms/                  ← NEW module
│   ├── __init__.py             current_platform singleton + plugin detection
│   ├── interface.py            Platform base class
│   ├── cuda.py                 CudaPlatform
│   ├── rocm.py                 RocmPlatform
│   ├── npu.py                  NpuPlatform
│   ├── xpu.py                  XpuPlatform
│   ├── hpu.py                  HpuPlatform
│   └── cpu.py                  CpuPlatform
├── srt/
│   ├── platforms/
│   │   └── __init__.py         re-export from sglang.platforms (SRT alias)
│   └── hardware_backend/
│       └── npu/                UNCHANGED — all NPU kernels/pools/runners stay here
└── multimodal_gen/runtime/
    └── platforms/
        └── __init__.py         re-export from sglang.platforms (Diffusion alias)
```
1. Unified Platform Class
A single Platform base class covers both runtimes. One hardware = one subclass.
Methods irrelevant to a runtime return a safe default and are never called.
```python
# sglang/platforms/interface.py
class Platform:
    """
    Unified hardware abstraction for SRT and SGL-Diffusion.
    Subclass this and publish via entry point to integrate a new device.
    """

    _enum: PlatformEnum  # CUDA | ROCM | NPU | XPU | HPU | CPU | OOT
    device_name: str     # "cuda" | "npu" | "musa" | …
    device_type: str
    dispatch_key: str = "CPU"

    # ── Type predicates (lru_cached) ─────────────────────────────────────
    def is_cuda(self) -> bool: ...
    def is_rocm(self) -> bool: ...
    def is_npu(self) -> bool: ...
    def is_xpu(self) -> bool: ...
    def is_hpu(self) -> bool: ...
    def is_cpu(self) -> bool: ...
    def is_out_of_tree(self) -> bool: ...
    def is_cuda_alike(self) -> bool: ...  # True for CUDA + ROCm

    # ── Hardware introspection (shared) ──────────────────────────────────
    def get_device_name(self, device_id: int = 0) -> str: ...
    def get_device_total_memory(self, device_id: int = 0) -> int: ...
    def get_device_capability(self, device_id: int = 0): ...
    def get_available_memory(self, device_id: int = 0) -> float: ...
    def get_torch_distributed_backend_str(self) -> str: ...
    def get_device_communicator_cls(self) -> Optional[str]: ...

    # ── Lifecycle (shared) ───────────────────────────────────────────────
    def initialize_device(self, rank: int, local_rank: int) -> None: pass

    # ── SRT-specific (default: no-op / None / {}) ────────────────────────
    def apply_server_args_defaults(self, args) -> None: pass
    def get_attention_backends(self) -> Dict[str, Callable]: return {}
    def get_default_attention_backend(self) -> Optional[str]: return None
    def get_graph_runner_class(self) -> Optional[Type]: return None
    def create_kv_pool(self, runner, **kwargs): return None
    def get_allocator_class(self) -> Optional[Type]: return None
    def get_quantization_methods(self) -> Dict[str, Type]: return {}
    def get_moe_a2a_backends(self) -> Dict[str, str]: return {}
    def get_disaggregation_backends(self) -> Dict[str, Type]: return {}

    # ── Diffusion-specific (default: sensible fallback) ──────────────────
    def get_attn_backend_cls_str(self, backend, head_size, dtype) -> str: return ""
    def is_amp_supported(self) -> bool: return True
    def enable_dit_layerwise_offload(self) -> bool: return True
    def seed_everything(self, seed: int) -> None: ...
    def verify_model_arch(self, arch: str) -> None: pass
    def verify_quantization(self, quant: str) -> None: pass
```

2. Plugin Registration: Entry Points
Vendors ship a Python package and declare their Platform subclass as a
packaging entry point (`[project.entry-points]` metadata per PEP 621):
```toml
# pyproject.toml of the vendor package (e.g. sglang-musa)
[project.entry-points."sglang.platform_plugins"]
musa = "sglang_musa.platform:MusaPlatform"
```

After `pip install sglang-musa`, SGLang discovers the plugin automatically on
the next startup — no upstream changes required.
For development without packaging, set:
```shell
python -m sglang.launch_server ...
```

3. Class Hierarchy
```mermaid
classDiagram
    class Platform {
        <<abstract>>
        +device_name: str
        +_enum: PlatformEnum
        +is_cuda() bool
        +is_rocm() bool
        +is_npu() bool
        +is_out_of_tree() bool
        +initialize_device(rank, local_rank)
        +get_attention_backends() Dict
        +get_graph_runner_class() Type
        +create_kv_pool(runner) Pool
        +apply_server_args_defaults(args)
        +get_attn_backend_cls_str(...) str
        +get_device_communicator_cls() str
    }
    class PlatformEnum {
        <<enumeration>>
        CUDA
        ROCM
        NPU
        XPU
        HPU
        CPU
        OOT
    }
    class CudaPlatform {
        +get_attention_backends() flashinfer/fa3/triton/...
        +get_graph_runner_class() CudaGraphRunner
        +create_kv_pool() MHATokenToKVPool
        +get_attn_backend_cls_str() FlashAttn/SDPA
    }
    class RocmPlatform {
        +get_attention_backends() aiter/wave/flashinfer
        +get_graph_runner_class() CudaGraphRunner
        +get_attn_backend_cls_str() AITER/SDPA
    }
    class NpuPlatform {
        +initialize_device() init_npu_backend()
        +apply_server_args_defaults()
        +get_attention_backends() ascend
        +get_graph_runner_class() NPUGraphRunner
        +create_kv_pool() NPUMHATokenToKVPool
    }
    class XpuPlatform
    class HpuPlatform
    class CpuPlatform
    class OotPlatform {
        <<vendor package>>
        _enum = OOT
        implements only needed methods
    }
    Platform --> PlatformEnum
    Platform <|-- CudaPlatform
    Platform <|-- RocmPlatform
    Platform <|-- NpuPlatform
    Platform <|-- XpuPlatform
    Platform <|-- HpuPlatform
    Platform <|-- CpuPlatform
    Platform <|-- OotPlatform
```
Writing a Platform Plugin
A vendor ships one Python package that works for both SRT and SGL-Diffusion.
Package layout
```
sglang_<vendor>/
├── platform.py          ← Platform subclass (only required file)
├── attention/
│   └── vendor_attn.py
├── graph_runner/
│   └── vendor_runner.py
├── memory/
│   └── vendor_pool.py
└── distributed/
    └── vendor_comm.py
```
```toml
# pyproject.toml
[project.entry-points."sglang.platform_plugins"]
<device> = "sglang_<vendor>.platform:VendorPlatform"
```
Minimal example (Moore Threads MUSA)
```python
# sglang_musa/platform.py
from sglang.platforms import Platform, PlatformEnum


class MusaPlatform(Platform):
    _enum = PlatformEnum.OOT
    device_name = "musa"
    device_type = "musa"

    def is_available(self) -> bool:
        try:
            import torch_musa
            return torch_musa.is_available()
        except ImportError:
            return False

    def initialize_device(self, rank, local_rank):
        import torch_musa
        torch_musa.set_device(local_rank)

    def get_device_communicator_cls(self):
        return "sglang_musa.distributed.MCCLCommunicator"

    def get_torch_distributed_backend_str(self):
        return "mccl"

    def apply_server_args_defaults(self, args):
        if args.attention_backend is None:
            args.attention_backend = "musa_flash"

    def get_attention_backends(self):
        from sglang_musa.attention import MusaFlashAttn
        return {"musa_flash": lambda r: MusaFlashAttn(r)}

    def get_graph_runner_class(self):
        from sglang_musa.graph_runner import MusaGraphRunner
        return MusaGraphRunner

    def create_kv_pool(self, runner, **kw):
        from sglang_musa.memory import MusaMHAPool
        return MusaMHAPool(runner, **kw)

    def get_attn_backend_cls_str(self, backend, head_size, dtype):
        return "sglang_musa.diffusion_attn.MusaFlashAttnBackend"

    def is_amp_supported(self):
        return True
```

Installation and activation
```shell
# Production — auto-activated on MUSA hardware after pip install
pip install sglang sglang-musa

# Development — no install required
python -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct
```

Plan
Phase 1 — Create sglang platforms (Q1)
- Create `sglang/platforms/interface.py` — `Platform` base class
- Create `sglang/platforms/{cuda,rocm,npu,xpu,hpu,cpu}.py`
- Create `sglang/platforms/__init__.py` — `current_platform` + plugin detection
- Declare the `sglang.platform_plugins` entry point group in `pyproject.toml`
- Unit tests with a `MockPlatform`
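The `MockPlatform` unit tests could look roughly like this. This is a self-contained sketch: the `Platform` stand-in below mirrors the safe defaults proposed in `interface.py`, while the real tests would import `sglang.platforms.Platform` instead.

```python
# Sketch of a MockPlatform unit test; the Platform stand-in
# reproduces the safe-default behavior proposed in this RFC.
from typing import Dict, Optional


class Platform:  # stand-in for sglang/platforms/interface.py
    device_name: str = "cpu"

    def is_npu(self) -> bool:
        return False

    def get_attention_backends(self) -> Dict:
        return {}  # safe default: no backends registered

    def get_default_attention_backend(self) -> Optional[str]:
        return None


class MockPlatform(Platform):
    device_name = "mock"

    def is_npu(self) -> bool:  # pretend to be an NPU in tests
        return True


def test_mock_platform():
    p = MockPlatform()
    assert p.is_npu()                          # overridden predicate
    assert p.get_attention_backends() == {}    # inherited safe default
    assert p.get_default_attention_backend() is None
```

The point of the mock is that a backend only overrides what it needs; everything else falls through to a default that is safe to call.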
Phase 2 — Migrate SGL-Diffusion (Q1)
- `multimodal_gen/runtime/platforms/interface.py`: inherit from `sglang.platforms.Platform`
- `multimodal_gen/runtime/platforms/__init__.py`: thin re-export
- All Diffusion behavior preserved, detection logic unified
Phase 3 — Wire SRT Core (Q2)
Replace 5 extension points with current_platform calls:
| File | Change |
|---|---|
| `model_executor/model_runner.py` | `initialize_device()` + `get_graph_runner_class()` |
| `model_executor/model_runner_kv_cache_mixin.py` | `create_kv_pool()` |
| `server_args.py` | `apply_server_args_defaults()` |
| `layers/attention/attention_registry.py` | `get_attention_backends()` |
| `distributed/parallel_state.py` | `get_device_communicator_cls()` |
Built-in CUDA/ROCm/NPU paths unchanged — Platform methods delegate to the
same code that runs today.
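As an illustration of what a wired call site looks like after Phase 3, here is a sketch with a stand-in platform object. `init_model_runner` and `_FakePlatform` are hypothetical names for this sketch only, not SRT code.

```python
# Hypothetical sketch of a model_runner call site after Phase 3:
# hardware branching collapses into calls on one platform object.
class _FakePlatform:  # stand-in for the real current_platform
    def initialize_device(self, rank: int, local_rank: int) -> None:
        self.initialized = (rank, local_rank)

    def get_graph_runner_class(self):
        return list  # placeholder runner class for this sketch


current_platform = _FakePlatform()


def init_model_runner(rank: int, local_rank: int):
    # Before: if _is_npu: init_npu_backend() elif _is_cuda: ...
    # After: one platform-agnostic call per extension point.
    current_platform.initialize_device(rank, local_rank)
    runner_cls = current_platform.get_graph_runner_class()
    return runner_cls() if runner_cls is not None else None
```

The runner no longer knows which device it is on; a vendor plugin changes behavior purely by returning different classes from the same extension points.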
Phase 4 — Clean Up is_xxx() Guards (Q2)
Replace the ~40 `is_xxx()` call sites across SRT with `current_platform.is_npu()`-style calls.
Existing is_npu() helpers in utils/common.py become thin wrappers that
delegate to current_platform for backward compatibility.
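The backward-compatibility shim could look like this. It is a self-contained sketch: `_CpuPlatform` stands in for the detected singleton that would really live in `sglang/platforms/__init__.py`.

```python
# Sketch of the Phase 4 shim: the legacy is_npu() helper in
# utils/common.py becomes a thin wrapper over the singleton.
class _CpuPlatform:  # stand-in for the detected current_platform
    def is_npu(self) -> bool:
        return False


current_platform = _CpuPlatform()


def is_npu() -> bool:
    """Legacy helper kept for backward compatibility; delegates."""
    return current_platform.is_npu()
```

Existing callers keep working unchanged while new code queries `current_platform` directly.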
Phase 5 — Documentation + Vendor Template (Q2)
- `docs/developer_guide/platform_plugin.md` — step-by-step vendor guide
- `sglang-platform-template` skeleton repository on GitHub
- Add Tier 2 plugin table to `docs/supported_models/hardware_support.md`