Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 111 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# SGLang Diffusion Router

A lightweight router for SGLang diffusion workers used in RL systems.
A lightweight router for SGLang diffusion workers used in RL systems. It provides worker registration, load balancing, health checking, weight refitting, and request proxying for diffusion generation APIs.

It provides worker registration, load balancing, health checking, refit weights and request proxying for diffusion generation APIs.
## Table of Contents

## API Reference
- [Overview](#overview)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Start diffusion workers](#start-diffusion-workers)
- [Start the router](#start-the-router)
- [Router API](#router-api)
- [Inference Endpoints](#inference-endpoints)
- [Videos Result Query](#videos-result-query)
- [Model Discovery and Health Checks](#model-discovery-and-health-checks)
- [Worker Management APIs](#worker-management-apis)
- [RL Related API](#rl-related-api)
- [Acknowledgment](#acknowledgment)

- `POST /add_worker`: add worker via query (`?url=`) or JSON body.
- `GET /list_workers`: list registered workers.
- `GET /health`: aggregated router health.
- `GET /health_workers`: per-worker health and active request counts.
- `POST /generate`: forwards to worker `/v1/images/generations`.
- `POST /generate_video`: forwards to worker `/v1/videos`; rejects image-only workers (`T2I`/`I2I`/`TI2I`) with `400`.
- `POST /update_weights_from_disk`: broadcast to all healthy workers.
- `GET|POST|PUT|DELETE /{path}`: catch-all proxy forwarding.

## Installation

Expand Down Expand Up @@ -41,7 +44,7 @@ cd ..

## Quick Start

### Co-launch workers and router via YAML config
### Co-Launch Workers and Router

Instead of starting workers manually, you can let the router spawn and manage them via a YAML config file.

Expand All @@ -59,7 +62,7 @@ launcher:
wait_timeout: 600
```

### Manual Launch Workers and Connect to Router
### Manually Launch Workers

```bash
# If connect to HuggingFace is not allowed
Expand Down Expand Up @@ -97,12 +100,26 @@ ROUTER = "http://localhost:30081"
resp = requests.get(f"{ROUTER}/health")
print(resp.json())

# List registered workers
resp = requests.get(f"{ROUTER}/list_workers")
# Register a worker
resp = requests.post(f"{ROUTER}/workers", json={"url": "http://localhost:30000"})
print(resp.json())

# List registered workers (with health/load)
resp = requests.get(f"{ROUTER}/workers")
print(resp.json())
worker_id = resp.json()["workers"][0]["worker_id"]

# Get / update worker details
resp = requests.get(f"{ROUTER}/workers/{worker_id}")
print(resp.json())
resp = requests.put(
f"{ROUTER}/workers/{worker_id}",
json={"is_dead": False, "refresh_video_support": True},
)
print(resp.json())

# Image generation request (returns base64-encoded image)
resp = requests.post(f"{ROUTER}/generate", json={
resp = requests.post(f"{ROUTER}/v1/images/generations", json={
"model": "Qwen/Qwen-Image",
"prompt": "a cute cat",
"num_images": 1,
Expand All @@ -117,10 +134,18 @@ with open("output.png", "wb") as f:
f.write(img)
print("Saved to output.png")

# Video generation request
# Note that Qwen-Image does not support video generation,
# so this request will fail.

# Check per-worker health and load
resp = requests.get(f"{ROUTER}/health_workers")
resp = requests.post(f"{ROUTER}/v1/videos", json={
"model": "Qwen/Qwen-Image",
"prompt": "a flowing river",
})
print(resp.json())
video_id = resp.json().get("video_id") or resp.json().get("id")
if video_id:
print(requests.get(f"{ROUTER}/v1/videos/{video_id}").json())

# Update weights from disk
resp = requests.post(f"{ROUTER}/update_weights_from_disk", json={
Expand All @@ -135,11 +160,16 @@ print(resp.json())
# Check router health
curl http://localhost:30081/health

# List registered workers
curl http://localhost:30081/list_workers
# Register a worker
curl -X POST http://localhost:30081/workers \
-H "Content-Type: application/json" \
-d '{"url": "http://localhost:30000"}'

# List registered workers (with health/load)
curl http://localhost:30081/workers

# Image generation request (returns base64-encoded image)
curl -X POST http://localhost:30081/generate \
curl -X POST http://localhost:30081/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen-Image",
Expand All @@ -149,7 +179,7 @@ curl -X POST http://localhost:30081/generate \
}'

# Decode and save the image locally
curl -s -X POST http://localhost:30081/generate \
curl -s -X POST http://localhost:30081/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen-Image",
Expand All @@ -165,12 +195,71 @@ with open('output.png', 'wb') as f:
print('Saved to output.png')
"

# Video generation request
curl -X POST http://localhost:30081/v1/videos \
-H "Content-Type: application/json" \
-d '{"model": "Qwen/Qwen-Image", "prompt": "a flowing river"}'

# Poll a specific video job by video_id
curl http://localhost:30081/v1/videos/{video_id}


curl -X POST http://localhost:30081/update_weights_from_disk \
-H "Content-Type: application/json" \
-d '{"model_path": "Qwen/Qwen-Image-2512"}'
```

## Router API

### Inference Endpoints

| Method | Path | Description |
|---|---|---|
| `POST` | `/v1/images/generations` | Entrypoint for text-to-image generation |
| `POST` | `/v1/videos` | Entrypoint for text-to-video generation |

### Videos Result Query

| Method | Path | Description |
|---|---|---|
| `GET` | `/v1/videos` | List or poll video jobs |
| `GET` | `/v1/videos/{video_id}` | Get status/details of a single video job |
| `GET` | `/v1/videos/{video_id}/content` | Download generated video content |

Video query routing is stable by `video_id`: router caches `video_id -> worker` on create (`POST /v1/videos`), then forwards detail/content queries to the same worker. Unknown `video_id` returns `404`.

### Model Discovery and Health Checks

| Method | Path | Description |
|---|---|---|
| `GET` | `/v1/models` | OpenAI-style model discovery |
| `GET` | `/health` | Basic health probe |

`GET /v1/models` aggregates model lists from healthy workers and de-duplicates by model `id`.

### Worker Management APIs

| Method | Path | Description |
|---|---|---|
| `POST` | `/workers` | Register a worker |
| `GET` | `/workers` | List workers (including health/load) |
| `GET` | `/workers/{worker_id}` | Get worker details |
| `PUT` | `/workers/{worker_id}` | Update worker configuration |
| `DELETE` | `/workers/{worker_id}` | Deregister a worker |

`worker_id` is the URL-encoded worker URL (machine-oriented), and each worker payload also includes `display_id` as a human-readable ID.

`PUT /workers/{worker_id}` currently supports:
- `is_dead` (boolean): quarantine (`true`) or recover (`false`) this worker.
- `refresh_video_support` (boolean): re-probe worker `/v1/models` capability.

### RL Related API

| Method | Path | Description |
|---|---|---|
| `POST` | `/update_weights_from_disk` | Reload weights from disk on all healthy workers |

Comment on lines 212 to +261
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This new API documentation is a great improvement in terms of being RESTful and well-structured. However, the code in this repository has not been updated to implement these new endpoints. The router still uses the old endpoints (e.g., /add_worker, /generate). This makes the documentation inaccurate. The implementation should be updated to match this documentation, or this documentation should be marked as a proposal for a future version.


## Acknowledgment

This project is derived from [radixark/miles#544](https://github.com/radixark/miles/pull/544). Thanks to the original authors.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ where = ["src"]

[tool.pytest.ini_options]
testpaths = ["tests/unit"]
pythonpath = ["src"]
95 changes: 81 additions & 14 deletions src/sglang_diffusion_routing/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,55 @@

import argparse
import asyncio
import os
import sys
import threading
from contextlib import suppress

from sglang_diffusion_routing import DiffusionRouter
from sglang_diffusion_routing.launcher import config as _lcfg


def _print_fired_up_banner(log_prefix: str) -> None:
banner = r"""

____ ____ _ ____ _ __ __ _
/ ___| / ___| | __ _ _ __ __ _| _ \(_)/ _|/ _|_ _ ___(_) ___ _ __
\___ \| | _| | / _` | '_ \ / _` | | | | | |_| |_| | | / __| |/ _ \| '_ \
___) | |_| | |__| (_| | | | | (_| | |_| | | _| _| |_| \__ \ | (_) | | | |
|____/ \____|_____\__,_|_| |_|\__, |____/|_|_| |_| \__,_|___/_|\___/|_| |_|
|___/

____ _ _____ _
| _ \ ___ _ _| |_ ___ ____| ___(_)_ __ ___ __| | _ _ _ __
| |_) / _ \| | | | __/ _ \ __| |_ | | '__/ _ \/ _` | | | | | '_ \
| _ < (_) | |_| | || __/ | | _| | | | | __/ (_| | | |_| | |_) |
|_| \_\___/ \__,_|\__\___|_| |_| |_|_| \___|\__,_| \__,_| .__/
|_|

"""
use_color = sys.stdout.isatty() and "NO_COLOR" not in os.environ
if not use_color:
print(f"{log_prefix} {banner}", flush=True)
return

colors = [
"\033[38;5;45m", # cyan
"\033[38;5;51m", # bright cyan
"\033[38;5;123m", # light cyan
"\033[38;5;159m", # pale blue
]
reset = "\033[0m"
colored_lines = []
for idx, line in enumerate(banner.splitlines()):
if line.strip():
color = colors[idx % len(colors)]
colored_lines.append(f"{color}{line}{reset}")
else:
colored_lines.append(line)
colored_banner = "\n".join(colored_lines)
print(f"{log_prefix} {colored_banner}", flush=True)


def _run_router_server(
args: argparse.Namespace,
router: DiffusionRouter,
Expand Down Expand Up @@ -40,15 +82,37 @@ async def _refresh_all_worker_video_support() -> None:

print(f"{log_prefix} starting router on {args.host}:{args.port}", flush=True)
print(
f"{log_prefix} workers: {list(router.worker_request_counts.keys()) or '(none - add via POST /add_worker)'}",
f"{log_prefix} workers: {list(router.worker_request_counts.keys()) or '(none - add via POST /workers)'}",
flush=True,
)
uvicorn.run(
router.app,
config = uvicorn.Config(
app=router.app,
host=args.host,
port=args.port,
log_level=getattr(args, "log_level", "info"),
)
server = uvicorn.Server(config)

async def _serve_with_banner() -> None:
banner_printed = False

async def _wait_and_print_banner() -> None:
nonlocal banner_printed
while not server.started and not server.should_exit:
await asyncio.sleep(0.1)
if server.started and not banner_printed:
banner_printed = True
_print_fired_up_banner(log_prefix)

watcher = asyncio.create_task(_wait_and_print_banner())
try:
await server.serve()
finally:
watcher.cancel()
with suppress(asyncio.CancelledError):
await watcher

asyncio.run(_serve_with_banner())


def _add_router_args(parser: argparse.ArgumentParser) -> None:
Expand Down Expand Up @@ -114,16 +178,19 @@ def _handle_router(args: argparse.Namespace) -> int:
launcher_cfg = _lcfg.load_launcher_config(args.launcher_config)
wait_timeout = launcher_cfg.wait_timeout
backend = _lcfg.create_backend(launcher_cfg)
backend.launch()
threading.Thread(
target=backend.wait_ready_and_register,
kwargs=dict(
register_func=router.register_worker,
timeout=wait_timeout,
log_prefix=log_prefix,
),
daemon=True,
).start()
launched_urls = backend.launch()
backend.wait_ready_and_register(
register_func=router.register_worker,
timeout=wait_timeout,
log_prefix=log_prefix,
)
registered_urls = set(router.worker_request_counts.keys())
pending_urls = [u for u in launched_urls if u not in registered_urls]
if pending_urls:
raise RuntimeError(
"managed workers failed to become healthy before router startup: "
+ ", ".join(pending_urls)
)

_run_router_server(args, router=router, log_prefix=log_prefix)
return 0
Expand Down
Loading