Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ repos:
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
Expand Down
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,30 @@ curl -X POST http://localhost:30081/update_weights_from_disk \
-d '{"model_path": "Qwen/Qwen-Image-2512"}'
```

### Auto-launch workers via YAML config

Instead of starting workers manually, you can let the router spawn and manage
them through a launcher backend.

**Local subprocess launcher** (`examples/local_launcher.yaml`):

```bash
sglang-d-router --port 30081 --launcher-config examples/local_launcher.yaml
```

```yaml
launcher:
backend: local
model: Qwen/Qwen-Image
num_workers: 2
num_gpus_per_worker: 1
worker_base_port: 10090
wait_timeout: 600
```

Fields not set in the YAML fall back to defaults defined in each backend's
config dataclass (see `LocalLauncherConfig`).

## Acknowledgment

This project is derived from [radixark/miles#544](https://github.com/radixark/miles/pull/544). Thanks to the original authors.
Expand Down
2 changes: 1 addition & 1 deletion development.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
pip install -e .
```

Run tests:
Run CPU-only tests:

```bash
pip install pytest
Expand Down
11 changes: 11 additions & 0 deletions examples/local_launcher.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
launcher:
model: Qwen/Qwen-Image

num_workers: 8
num_gpus_per_worker: 1
worker_host: "127.0.0.1"
worker_base_port: 10090

worker_extra_args: "--dit-cpu-offload false --text-encoder-cpu-offload false"
Comment thread
dreamyang-liu marked this conversation as resolved.
Comment thread
dreamyang-liu marked this conversation as resolved.

wait_timeout: 600
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ license = { text = "MIT" }
dependencies = [
"fastapi>=0.110",
"httpx>=0.27",
"omegaconf>=2.3",
"uvicorn>=0.30",
]
classifiers = [
Expand Down
56 changes: 47 additions & 9 deletions src/sglang_diffusion_routing/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import argparse
import asyncio
import sys
import threading

from sglang_diffusion_routing import DiffusionRouter
from sglang_diffusion_routing.launcher import config as _lcfg


def _run_router_server(
args: argparse.Namespace,
worker_urls: list[str] | None = None,
router: DiffusionRouter,
log_prefix: str = "[router]",
) -> None:
try:
Expand All @@ -22,10 +24,7 @@ def _run_router_server(
"uvicorn is required to run router. Install with: pip install uvicorn"
) from exc

worker_urls = list(
worker_urls if worker_urls is not None else args.worker_urls or []
)
router = DiffusionRouter(args, verbose=args.verbose)
worker_urls = list(args.worker_urls or [])
refresh_tasks = []
for url in worker_urls:
normalized_url = router.normalize_worker_url(url)
Expand Down Expand Up @@ -97,13 +96,52 @@ def _add_router_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--log-level", type=str, default="info", help="Uvicorn log level."
)
parser.add_argument(
"--launcher-config",
type=str,
default=None,
dest="launcher_config",
help="YAML config for launching router managed workers (see examples/local_launcher.yaml).",
)


def _handle_router(args: argparse.Namespace) -> int:
_run_router_server(
args, worker_urls=list(args.worker_urls), log_prefix="[sglang-d-router]"
)
return 0
log_prefix, backend, router = "[sglang-d-router]", None, None

try:
router = DiffusionRouter(args, verbose=args.verbose)
if args.launcher_config is not None:
launcher_cfg = _lcfg.load_launcher_config(args.launcher_config)
wait_timeout = launcher_cfg.wait_timeout
Comment thread
dreamyang-liu marked this conversation as resolved.
backend = _lcfg.create_backend(launcher_cfg)
backend.launch()
threading.Thread(
target=backend.wait_ready_and_register,
kwargs=dict(
register_func=router.register_worker,
timeout=wait_timeout,
log_prefix=log_prefix,
),
daemon=True,
).start()

_run_router_server(args, router=router, log_prefix=log_prefix)
return 0
finally:
# TODO (mengyang, shuwen, chenyang): refactor the exit logic of router and backend.
if router is not None:
try:
asyncio.run(router.client.aclose())
except Exception as exc:
print(
f"{log_prefix} warning: failed to close router client: {exc}",
file=sys.stderr,
flush=True,
)
if backend is not None:
print(f"{log_prefix} shutting down managed workers...", flush=True)
backend.shutdown()
print(f"{log_prefix} all managed workers terminated.", flush=True)


def build_parser() -> argparse.ArgumentParser:
Expand Down
26 changes: 26 additions & 0 deletions src/sglang_diffusion_routing/launcher/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Launcher backends for spinning up SGLang diffusion workers.

Currently only the local backend is supported; it launches workers as local
subprocesses. This module is kept as an extension point for future backends
such as Slurm or Kubernetes.
"""

from sglang_diffusion_routing.launcher.backend import (
    LaunchedWorker,
    LauncherBackend,
    WorkerLaunchResult,
)
from sglang_diffusion_routing.launcher.config import (
    create_backend,
    load_launcher_config,
)
from sglang_diffusion_routing.launcher.local import LocalLauncher, LocalLauncherConfig

__all__ = [
    "LaunchedWorker",
    "LauncherBackend",
    "LocalLauncher",
    "LocalLauncherConfig",
    "WorkerLaunchResult",
    "create_backend",
    "load_launcher_config",
]
53 changes: 53 additions & 0 deletions src/sglang_diffusion_routing/launcher/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Abstract base class and shared data types for launcher backends."""

from __future__ import annotations

import subprocess
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class LaunchedWorker:
    """One worker under launcher management: its URL plus the owning process."""

    # Base URL the worker is reachable at.
    url: str
    # Handle to the spawned worker subprocess.
    process: subprocess.Popen


@dataclass
class WorkerLaunchResult:
    """Aggregated outcome of launching a batch of worker subprocesses."""

    # Workers the launcher brought up, in launch order.
    workers: list[LaunchedWorker] = field(default_factory=list)
    # Every subprocess spawned during the launch
    # (presumably a superset of `workers`' processes — confirm with backend impls).
    all_processes: list[subprocess.Popen] = field(default_factory=list)

    @property
    def urls(self) -> list[str]:
        """Base URLs of the managed workers, in the same order as ``workers``."""
        return [worker.url for worker in self.workers]


class LauncherBackend(ABC):
    """Contract shared by every worker-launcher backend.

    Implementations follow a single lifecycle, in order:

    1. ``launch`` — spawn the workers.
    2. ``wait_ready_and_register`` — block until workers are healthy and
       hand each one's URL to the router.
    3. ``shutdown`` — tear down everything the backend started.
    """

    @abstractmethod
    def launch(self) -> list[str]:
        """Start the managed workers and return their base URLs."""

    @abstractmethod
    def wait_ready_and_register(
        self,
        register_func: Callable[[str], None],
        timeout: int,
        log_prefix: str = "[launcher]",
    ) -> None:
        """Block until each worker is healthy, then call ``register_func(url)`` for it.

        ``timeout`` bounds the wait in seconds; ``log_prefix`` tags progress output.
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Terminate and clean up every worker this backend manages."""
70 changes: 70 additions & 0 deletions src/sglang_diffusion_routing/launcher/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""YAML configuration loading and backend factory."""

from __future__ import annotations

from pathlib import Path

import yaml
from omegaconf import DictConfig, OmegaConf

from sglang_diffusion_routing.launcher.backend import LauncherBackend
from sglang_diffusion_routing.launcher.local import LocalLauncher, LocalLauncherConfig

SCHEMA_REGISTRY: dict[str, type] = {
"local": LocalLauncherConfig,
}

BACKEND_REGISTRY: dict[str, type[LauncherBackend]] = {
"local": LocalLauncher,
}


def load_launcher_config(config_path: str) -> DictConfig:
    """Load a launcher YAML file and return a schema-validated OmegaConf config.

    Steps:
      1. Parse the YAML and extract the top-level ``launcher`` mapping.
      2. Select the structured schema from the mapping's ``backend`` key
         (defaults to ``"local"``).
      3. Overlay the YAML values on the schema defaults and return the merge.

    Raises:
        FileNotFoundError: If ``config_path`` does not point at a file.
        ValueError: If the file lacks a ``launcher`` mapping, the mapping is
            not a dictionary, or it names an unregistered backend.
    """
    path = Path(config_path)
    if not path.is_file():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with path.open() as f:
        parsed = yaml.safe_load(f)

    if not isinstance(parsed, dict) or "launcher" not in parsed:
        raise ValueError(
            f"Config file must contain a top-level 'launcher' key: {config_path}"
        )

    section = parsed["launcher"]
    if not isinstance(section, dict):
        raise ValueError("'launcher' must be a dictionary")

    backend_name = section.get("backend", "local")
    if backend_name not in SCHEMA_REGISTRY:
        available = ", ".join(sorted(SCHEMA_REGISTRY))
        raise ValueError(
            f"Unknown launcher backend: {backend_name!r}. "
            f"Available backends: {available}"
        )

    # Merging onto the structured schema fills unset fields with the
    # dataclass defaults and lets OmegaConf validate the YAML's keys/types.
    defaults = OmegaConf.structured(SCHEMA_REGISTRY[backend_name])
    overrides = OmegaConf.create(section)
    return OmegaConf.merge(defaults, overrides)


def create_backend(config: DictConfig) -> LauncherBackend:
    """Instantiate the launcher backend named by ``config.backend``.

    Raises:
        ValueError: If ``config.backend`` is not a registered backend.
    """
    backend_name = config.backend
    if backend_name not in BACKEND_REGISTRY:
        available = ", ".join(sorted(BACKEND_REGISTRY))
        raise ValueError(
            f"Unknown launcher backend: {backend_name!r}. "
            f"Available backends: {available}"
        )
    return BACKEND_REGISTRY[backend_name](config)
Loading