Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ repos:
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
Expand Down
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,30 @@ curl -X POST http://localhost:30081/update_weights_from_disk \
-d '{"model_path": "Qwen/Qwen-Image-2512"}'
```

### Auto-launch workers via YAML config

Instead of starting workers manually, you can let the router spawn and manage
them through a launcher backend.

**Local subprocess launcher** (`examples/local_launcher.yaml`):

```bash
sglang-d-router --port 30081 --launcher-config examples/local_launcher.yaml
```

```yaml
launcher:
backend: local
model: Qwen/Qwen-Image
num_workers: 2
num_gpus_per_worker: 1
worker_base_port: 10090
wait_timeout: 600
```

Fields not set in the YAML fall back to defaults defined in each backend's
config dataclass (see `LocalLauncherConfig`).

## Acknowledgment

This project is derived from [radixark/miles#544](https://github.com/radixark/miles/pull/544). Thanks to the original authors.
Expand Down
13 changes: 13 additions & 0 deletions examples/local_launcher.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
launcher:
backend: local
Comment thread
dreamyang-liu marked this conversation as resolved.
Outdated
model: Qwen/Qwen-Image

num_workers: 2
num_gpus_per_worker: 2
worker_host: "127.0.0.1"
worker_base_port: 10090

# worker_gpu_ids: ["0,1", "2,3"] # optional: one entry per worker, used as that worker's CUDA_VISIBLE_DEVICES; auto-detected if omitted
worker_extra_args: "--dit-cpu-offload false --text-encoder-cpu-offload false"
Comment thread
dreamyang-liu marked this conversation as resolved.
Comment thread
dreamyang-liu marked this conversation as resolved.

wait_timeout: 600
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ license = { text = "MIT" }
dependencies = [
"fastapi>=0.110",
"httpx>=0.27",
"omegaconf>=2.3",
"uvicorn>=0.30",
]
classifiers = [
Expand Down
52 changes: 43 additions & 9 deletions src/sglang_diffusion_routing/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import argparse
import asyncio
import sys
import threading

from sglang_diffusion_routing import DiffusionRouter
from sglang_diffusion_routing.launcher import config as _lcfg


def _run_router_server(
args: argparse.Namespace,
worker_urls: list[str] | None = None,
router: DiffusionRouter,
log_prefix: str = "[router]",
) -> None:
try:
Expand All @@ -22,10 +24,7 @@ def _run_router_server(
"uvicorn is required to run router. Install with: pip install uvicorn"
) from exc

worker_urls = list(
worker_urls if worker_urls is not None else args.worker_urls or []
)
router = DiffusionRouter(args, verbose=args.verbose)
worker_urls = list(args.worker_urls or [])
refresh_tasks = []
for url in worker_urls:
normalized_url = router.normalize_worker_url(url)
Expand Down Expand Up @@ -97,13 +96,48 @@ def _add_router_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--log-level", type=str, default="info", help="Uvicorn log level."
)
parser.add_argument(
"--launcher-config",
type=str,
default=None,
dest="launcher_config",
help="YAML config for launching router managed workers (see examples/local_launcher.yaml).",
)


def _handle_router(args: argparse.Namespace) -> int:
_run_router_server(
args, worker_urls=list(args.worker_urls), log_prefix="[sglang-d-router]"
)
return 0
log_prefix = "[sglang-d-router]"
backend = None

try:
router = DiffusionRouter(args, verbose=args.verbose)

if args.launcher_config is not None:
launcher_cfg = _lcfg.load_launcher_config(args.launcher_config)
wait_timeout = launcher_cfg.wait_timeout
Comment thread
dreamyang-liu marked this conversation as resolved.
backend = _lcfg.create_backend(launcher_cfg)
backend.launch()
threading.Thread(
target=backend.wait_ready_and_register,
kwargs=dict(
register_fn=router.register_worker,
timeout=wait_timeout,
log_prefix=log_prefix,
),
daemon=True,
).start()

_run_router_server(args, router=router, log_prefix=log_prefix)
return 0
finally:
try:
asyncio.run(router.client.aclose())
except Exception:
pass
if backend is not None:
print(f"{log_prefix} shutting down managed workers...", flush=True)
backend.shutdown()
print(f"{log_prefix} all managed workers terminated.", flush=True)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left a todo here to refactor. But we can leave it right now.



def build_parser() -> argparse.ArgumentParser:
Expand Down
26 changes: 26 additions & 0 deletions src/sglang_diffusion_routing/launcher/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Launcher backends for spinning up SGLang diffusion workers.

Supported backends:
- ``local``: launch workers as local subprocesses.
"""

from sglang_diffusion_routing.launcher.backend import (
LaunchedWorker,
LauncherBackend,
WorkerLaunchResult,
)
from sglang_diffusion_routing.launcher.config import (
create_backend,
load_launcher_config,
)
from sglang_diffusion_routing.launcher.local import LocalLauncher, LocalLauncherConfig

__all__ = [
"LaunchedWorker",
"LauncherBackend",
"LocalLauncher",
"LocalLauncherConfig",
"WorkerLaunchResult",
"create_backend",
"load_launcher_config",
]
58 changes: 58 additions & 0 deletions src/sglang_diffusion_routing/launcher/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Abstract base class and shared data types for launcher backends."""

from __future__ import annotations

import subprocess
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class LaunchedWorker:
    """A worker managed by a launcher backend."""

    # Base URL of the worker's HTTP endpoint.
    url: str
    # Handle to the spawned worker subprocess, used for lifecycle management.
    process: subprocess.Popen


@dataclass
class WorkerLaunchResult:
    """Everything produced by a launch: the managed workers together with
    every subprocess spawned on their behalf."""

    workers: list[LaunchedWorker] = field(default_factory=list)
    all_processes: list[subprocess.Popen] = field(default_factory=list)

    @property
    def urls(self) -> list[str]:
        """Base URLs of all launched workers, in launch order."""
        return [worker.url for worker in self.workers]


class LauncherBackend(ABC):
    """Interface for launching and managing SGLang diffusion workers.

    Each backend implements a different deployment strategy (local subprocess,
    Kubernetes, Ray etc.) but exposes the same lifecycle:
    launch → wait_ready_and_register → shutdown.
    """

    @abstractmethod
    def launch(self) -> list[str]:
        """Launch workers and return their base URLs.

        Returns the URL of every worker this backend manages; the workers
        may not yet be healthy (see ``wait_ready_and_register``).
        """

    @abstractmethod
    def wait_ready_and_register(
        self,
        register_fn: Callable[[str], None],
        timeout: int,
        log_prefix: str = "[launcher]",
    ) -> None:
        """Wait for workers to become healthy and register each via register_fn.

        Args:
            register_fn: Called once per worker with that worker's base URL.
            timeout: Maximum time to wait for workers to become healthy.
            log_prefix: Prefix prepended to progress log messages.

        Workers are checked concurrently; each is registered as soon as it is
        healthy rather than waiting for all workers to be ready.
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Terminate or clean up all managed workers."""
76 changes: 76 additions & 0 deletions src/sglang_diffusion_routing/launcher/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""YAML configuration loading and backend factory."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import yaml
from omegaconf import DictConfig, OmegaConf

from sglang_diffusion_routing.launcher.backend import LauncherBackend
from sglang_diffusion_routing.launcher.local import LocalLauncher, LocalLauncherConfig

SCHEMA_REGISTRY: dict[str, type] = {
"local": LocalLauncherConfig,
}

BACKEND_REGISTRY: dict[str, type[LauncherBackend]] = {
"local": LocalLauncher,
}


def load_launcher_config(config_path: str) -> DictConfig:
    """Read a YAML config file and return a validated OmegaConf config.

    The file must contain a top-level ``launcher`` mapping. Its ``backend``
    key (default ``"local"``) selects a structured schema dataclass from
    SCHEMA_REGISTRY; the YAML values are then merged over that schema so
    unspecified fields keep their dataclass defaults.

    Raises:
        FileNotFoundError: If ``config_path`` does not point to a file.
        ValueError: If the YAML lacks a ``launcher`` mapping or names an
            unknown backend.
    """
    cfg_file = Path(config_path)
    if not cfg_file.is_file():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with cfg_file.open() as handle:
        document = yaml.safe_load(handle)

    if not isinstance(document, dict) or "launcher" not in document:
        raise ValueError(
            f"Config file must contain a top-level 'launcher' key: {config_path}"
        )

    launcher_section: dict[str, Any] = document["launcher"]
    if not isinstance(launcher_section, dict):
        raise ValueError("'launcher' must be a mapping")

    backend_name = launcher_section.get("backend", "local")
    schema_cls = SCHEMA_REGISTRY.get(backend_name)
    if schema_cls is None:
        available = ", ".join(sorted(SCHEMA_REGISTRY))
        raise ValueError(
            f"Unknown launcher backend: {backend_name!r}. "
            f"Available backends: {available}"
        )

    # Merge YAML values onto the structured schema so missing fields fall
    # back to the dataclass defaults and types are validated by OmegaConf.
    merged: DictConfig = OmegaConf.merge(  # type: ignore[assignment]
        OmegaConf.structured(schema_cls), OmegaConf.create(launcher_section)
    )
    return merged


def create_backend(config: DictConfig) -> LauncherBackend:
    """Instantiate the LauncherBackend implementation named by ``config.backend``.

    The backend name selects the implementation class from BACKEND_REGISTRY;
    the matching class is constructed with the full validated config.

    Raises:
        ValueError: If ``config.backend`` names no registered backend.
    """
    backend_name = config.backend
    backend_cls = BACKEND_REGISTRY.get(backend_name)
    if backend_cls is not None:
        return backend_cls(config)
    available = ", ".join(sorted(BACKEND_REGISTRY))
    raise ValueError(
        f"Unknown launcher backend: {backend_name!r}. "
        f"Available backends: {available}"
    )
Loading