[Feature] Introduce hardware plugin system #20372

@22dimensions

Motivation

SGLang supports CUDA, ROCm, Intel XPU, Ascend NPU, Habana HPU, and CPU. As more
hardware vendors seek integration with SGLang (Moore Threads, …),
the current architecture has the following problems:

  • Hardware logic is scattered across many files, and there is no stable
    integration contract. Take NPU as an example: the only way in is to read the source, find every
    if _is_npu: guard, and add a new branch alongside it.
  • Maintaining a backend without dedicated CI is unsustainable. Backends
    that lack full CI coverage and continuous contribution slowly rot.
  • New features stall. Every generic optimization (chunked prefill, speculative
    decoding, disaggregation) must be propagated to every backend separately or
    is silently broken.

Proposed Change

SGL-Diffusion already solved this problem with a Platform class and a
current_platform singleton. SRT has not. This RFC proposes:

  1. A unified Platform base class shared by both runtimes.
  2. A plugin mechanism (Python entry points) so vendors can ship support as a standalone pip install package.

The directory layout after migration from sgl-diffusion to srt:

python/sglang/
├── platforms/                      ← NEW module
│   ├── __init__.py                 current_platform singleton + plugin detection
│   ├── interface.py                Platform base class
│   ├── cuda.py                     CudaPlatform
│   ├── rocm.py                     RocmPlatform
│   ├── npu.py                      NpuPlatform 
│   ├── xpu.py                      XpuPlatform
│   ├── hpu.py                      HpuPlatform
│   └── cpu.py                      CpuPlatform
├── srt/
│   ├── platforms/
│   │   └── __init__.py            re-export from sglang.platforms (SRT alias)
│   └── hardware_backend/
│       └── npu/                    UNCHANGED — all NPU kernels/pools/runners stay here
└── multimodal_gen/runtime/
    └── platforms/
        └── __init__.py             re-export from sglang.platforms (Diffusion alias)

1. Unified Platform Class

A single Platform base class covers both runtimes. One hardware = one subclass.
Methods irrelevant to a runtime return a safe default and are never called.

# sglang/platforms/interface.py
from typing import Callable, Dict, Optional, Type

class Platform:
    """
    Unified hardware abstraction for SRT and SGL-Diffusion.
    Subclass this and publish via entry point to integrate a new device.
    """
    _enum: PlatformEnum          # CUDA | ROCM | NPU | XPU | HPU | CPU | OOT
    device_name: str             # "cuda" | "npu" | "musa" | …
    device_type: str
    dispatch_key: str = "CPU"
    # ── Type predicates (lru_cached) ─────────────────────────────────────
    def is_cuda(self) -> bool: ...
    def is_rocm(self) -> bool: ...
    def is_npu(self) -> bool: ...
    def is_xpu(self) -> bool: ...
    def is_hpu(self) -> bool: ...
    def is_cpu(self) -> bool: ...
    def is_out_of_tree(self) -> bool: ...
    def is_cuda_alike(self) -> bool: ...   # True for CUDA + ROCm

    # ── Hardware introspection (shared) ──────────────────────────────────
    def get_device_name(self, device_id: int = 0) -> str: ...
    def get_device_total_memory(self, device_id: int = 0) -> int: ...
    def get_device_capability(self, device_id: int = 0): ...
    def get_available_memory(self, device_id: int = 0) -> float: ...
    def get_torch_distributed_backend_str(self) -> str: ...
    def get_device_communicator_cls(self) -> Optional[str]: ...

    # ── Lifecycle (shared) ───────────────────────────────────────────────
    def initialize_device(self, rank: int, local_rank: int) -> None: pass

    # ── SRT-specific (default: no-op / None / {}) ────────────────────────
    def apply_server_args_defaults(self, args) -> None: pass
    def get_attention_backends(self) -> Dict[str, Callable]: return {}
    def get_default_attention_backend(self) -> Optional[str]: return None
    def get_graph_runner_class(self) -> Optional[Type]: return None
    def create_kv_pool(self, runner, **kwargs): return None
    def get_allocator_class(self) -> Optional[Type]: return None
    def get_quantization_methods(self) -> Dict[str, Type]: return {}
    def get_moe_a2a_backends(self) -> Dict[str, str]: return {}
    def get_disaggregation_backends(self) -> Dict[str, Type]: return {}

    # ── Diffusion-specific (default: sensible fallback) ──────────────────
    def get_attn_backend_cls_str(self, backend, head_size, dtype) -> str: return ""
    def is_amp_supported(self) -> bool: return True
    def enable_dit_layerwise_offload(self) -> bool: return True
    def seed_everything(self, seed: int) -> None: ...
    def verify_model_arch(self, arch: str) -> None: pass
    def verify_quantization(self, quant: str) -> None: pass

2. Plugin Registration: Entry Points

Vendors ship a Python package and declare their Platform subclass as an
entry point in the sglang.platform_plugins group (the standard
[project.entry-points] table, PEP 621 metadata):

# pyproject.toml of the vendor package (e.g. sglang-musa)
[project.entry-points."sglang.platform_plugins"]
musa = "sglang_musa.platform:MusaPlatform"

After pip install sglang-musa, SGLang discovers the plugin automatically on
the next startup — no upstream changes required.

For development without packaging, run:

python -m sglang.launch_server ...

3. Class Hierarchy

classDiagram
    class Platform {
        <<abstract>>
        +device_name: str
        +_enum: PlatformEnum
        +is_cuda() bool
        +is_rocm() bool
        +is_npu() bool
        +is_out_of_tree() bool
        +initialize_device(rank, local_rank)
        +get_attention_backends() Dict
        +get_graph_runner_class() Type
        +create_kv_pool(runner) Pool
        +apply_server_args_defaults(args)
        +get_attn_backend_cls_str(...) str
        +get_device_communicator_cls() str
    }

    class PlatformEnum {
        <<enumeration>>
        CUDA
        ROCM
        NPU
        XPU
        HPU
        CPU
        OOT
    }

    class CudaPlatform {
        +get_attention_backends() flashinfer/fa3/triton/...
        +get_graph_runner_class() CudaGraphRunner
        +create_kv_pool() MHATokenToKVPool
        +get_attn_backend_cls_str() FlashAttn/SDPA
    }

    class RocmPlatform {
        +get_attention_backends() aiter/wave/flashinfer
        +get_graph_runner_class() CudaGraphRunner
        +get_attn_backend_cls_str() AITER/SDPA
    }

    class NpuPlatform {
        +initialize_device() init_npu_backend()
        +apply_server_args_defaults()
        +get_attention_backends() ascend
        +get_graph_runner_class() NPUGraphRunner
        +create_kv_pool() NPUMHATokenToKVPool
    }

    class XpuPlatform
    class HpuPlatform
    class CpuPlatform

    class OotPlatform {
        <<vendor package>>
        _enum = OOT
        implements only needed methods
    }

    Platform --> PlatformEnum
    Platform <|-- CudaPlatform
    Platform <|-- RocmPlatform
    Platform <|-- NpuPlatform
    Platform <|-- XpuPlatform
    Platform <|-- HpuPlatform
    Platform <|-- CpuPlatform
    Platform <|-- OotPlatform

Writing a Platform Plugin

A vendor ships one Python package that works for both SRT and SGL-Diffusion.

Package layout

sglang_<vendor>/
├── platform.py            ← Platform subclass (only required file)
├── attention/
│   └── vendor_attn.py
├── graph_runner/
│   └── vendor_runner.py
├── memory/
│   └── vendor_pool.py
└── distributed/
    └── vendor_comm.py

# pyproject.toml
[project.entry-points."sglang.platform_plugins"]
<device> = "sglang_<vendor>.platform:VendorPlatform"

Minimal example (Moore Threads MUSA)

# sglang_musa/platform.py
from sglang.platforms import Platform, PlatformEnum

class MusaPlatform(Platform):
    _enum       = PlatformEnum.OOT
    device_name = "musa"
    device_type = "musa"

    def is_available(self) -> bool:
        try:
            import torch_musa
            return torch_musa.is_available()
        except ImportError:
            return False

    def initialize_device(self, rank, local_rank):
        import torch_musa
        torch_musa.set_device(local_rank)

    def get_device_communicator_cls(self):
        return "sglang_musa.distributed.MCCLCommunicator"

    def get_torch_distributed_backend_str(self):
        return "mccl"

    def apply_server_args_defaults(self, args):
        if args.attention_backend is None:
            args.attention_backend = "musa_flash"

    def get_attention_backends(self):
        from sglang_musa.attention import MusaFlashAttn
        return {"musa_flash": lambda r: MusaFlashAttn(r)}

    def get_graph_runner_class(self):
        from sglang_musa.graph_runner import MusaGraphRunner
        return MusaGraphRunner

    def create_kv_pool(self, runner, **kw):
        from sglang_musa.memory import MusaMHAPool
        return MusaMHAPool(runner, **kw)

    def get_attn_backend_cls_str(self, backend, head_size, dtype):
        return "sglang_musa.diffusion_attn.MusaFlashAttnBackend"

    def is_amp_supported(self):
        return True

Installation and activation

# Production — auto-activated on MUSA hardware after pip install
pip install sglang sglang-musa

# Development — no install required
python -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct

Plan

Phase 1 — Create sglang platforms (Q1)

  • Create sglang/platforms/interface.py — Platform base class
  • Create sglang/platforms/{cuda,rocm,npu,xpu,hpu,cpu}.py
  • Create sglang/platforms/__init__.py — current_platform + plugin detection
  • Declare sglang.platform_plugins entry point group in pyproject.toml
  • Unit tests with a MockPlatform
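The MockPlatform unit tests could look like the following self-contained sketch. It uses a trimmed stand-in for the real interface; the names mirror the proposal above but are not the final API:

```python
from enum import Enum
from typing import Optional, Type

class PlatformEnum(Enum):
    CUDA = "cuda"
    OOT = "oot"   # out-of-tree vendor plugin (other members omitted)

class Platform:
    _enum: PlatformEnum
    device_name: str

    def is_out_of_tree(self) -> bool:
        return self._enum is PlatformEnum.OOT

    def get_graph_runner_class(self) -> Optional[Type]:
        return None   # safe default: core falls back to the generic path

class MockPlatform(Platform):
    _enum = PlatformEnum.OOT
    device_name = "mock"

# Tests exercise core dispatch logic against the mock, no hardware needed.
mock = MockPlatform()
assert mock.is_out_of_tree()
assert mock.get_graph_runner_class() is None
```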

Phase 2 — Migrate SGL-Diffusion (Q1)

  • multimodal_gen/runtime/platforms/interface.py: inherit from sglang.platforms.Platform
  • multimodal_gen/runtime/platforms/__init__.py: thin re-export
  • All Diffusion behavior preserved, detection logic unified

Phase 3 — Wire SRT Core (Q2)

Replace 5 extension points with current_platform calls:

| File | Change |
| --- | --- |
| model_executor/model_runner.py | initialize_device() + get_graph_runner_class() |
| model_executor/model_runner_kv_cache_mixin.py | create_kv_pool() |
| server_args.py | apply_server_args_defaults() |
| layers/attention/attention_registry.py | get_attention_backends() |
| distributed/parallel_state.py | get_device_communicator_cls() |

Built-in CUDA/ROCm/NPU paths unchanged — Platform methods delegate to the
same code that runs today.

Phase 4 — Clean Up is_xxx() Guards (Q2)

Roughly 40 is_xxx() call-sites across SRT are replaced with the corresponding
current_platform.is_npu() / is_cuda() / … calls. Existing is_npu() helpers in
utils/common.py become thin wrappers that delegate to current_platform for
backward compatibility.
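The backward-compatibility shim itself is trivial. A self-contained sketch (the stub below stands in for the real current_platform singleton; module paths are assumptions):

```python
class _StubPlatform:                      # stands in for current_platform
    device_name = "npu"

    def is_npu(self) -> bool:
        return self.device_name == "npu"

# Real code would do: from sglang.platforms import current_platform
current_platform = _StubPlatform()

def is_npu() -> bool:
    """Legacy utils/common.py helper, kept for backward compatibility;
    now delegates to current_platform instead of probing hardware itself."""
    return current_platform.is_npu()
```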

Phase 5 — Documentation + Vendor Template (Q2)

  • docs/developer_guide/platform_plugin.md — step-by-step vendor guide
  • sglang-platform-template skeleton repository on GitHub
  • Add Tier 2 plugin table to docs/supported_models/hardware_support.md
