Skip to content

Commit 74604ca

Browse files
Merge pull request zhaochenyang20#17 from alphabetc1/refactor/api
[RFC] refactor: RESTful and sgl/smg compliant API
2 parents bf11ef6 + 9a78c3e commit 74604ca

5 files changed

Lines changed: 892 additions & 103 deletions

File tree

README.md

Lines changed: 111 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
# SGLang Diffusion Router
22

3-
A lightweight router for SGLang diffusion workers used in RL systems.
3+
A lightweight router for SGLang diffusion workers used in RL systems. It provides worker registration, load balancing, health checking, weight refitting, and request proxying for diffusion generation APIs.
44

5-
It provides worker registration, load balancing, health checking, refit weights and request proxying for diffusion generation APIs.
5+
## Table of Contents
66

7-
## API Reference
7+
- [Overview](#overview)
8+
- [Installation](#installation)
9+
- [Quick Start](#quick-start)
10+
- [Start diffusion workers](#start-diffusion-workers)
11+
- [Start the router](#start-the-router)
12+
- [Router API](#router-api)
13+
- [Inference Endpoints](#inference-endpoints)
14+
- [Videos Result Query](#videos-result-query)
15+
- [Model Discovery and Health Checks](#model-discovery-and-health-checks)
16+
- [Worker Management APIs](#worker-management-apis)
17+
- [RL Related API](#rl-related-api)
18+
- [Acknowledgment](#acknowledgment)
819

9-
- `POST /add_worker`: add worker via query (`?url=`) or JSON body.
10-
- `GET /list_workers`: list registered workers.
11-
- `GET /health`: aggregated router health.
12-
- `GET /health_workers`: per-worker health and active request counts.
13-
- `POST /generate`: forwards to worker `/v1/images/generations`.
14-
- `POST /generate_video`: forwards to worker `/v1/videos`; rejects image-only workers (`T2I`/`I2I`/`TI2I`) with `400`.
15-
- `POST /update_weights_from_disk`: broadcast to all healthy workers.
16-
- `GET|POST|PUT|DELETE /{path}`: catch-all proxy forwarding.
1720

1821
## Installation
1922

@@ -41,7 +44,7 @@ cd ..
4144

4245
## Quick Start
4346

44-
### Co-launch workers and router via YAML config
47+
### Co-Launch Workers and Router
4548

4649
Instead of starting workers manually, you can let the router spawn and manage them via a YAML config file.
4750

@@ -59,7 +62,7 @@ launcher:
5962
wait_timeout: 600
6063
```
6164
62-
### Manual Launch Workers and Connect to Router
65+
### Manual Launch Workers
6366
6467
```bash
6568
# If connecting to HuggingFace is not allowed
@@ -97,12 +100,26 @@ ROUTER = "http://localhost:30081"
97100
resp = requests.get(f"{ROUTER}/health")
98101
print(resp.json())
99102

100-
# List registered workers
101-
resp = requests.get(f"{ROUTER}/list_workers")
103+
# Register a worker
104+
resp = requests.post(f"{ROUTER}/workers", json={"url": "http://localhost:30000"})
105+
print(resp.json())
106+
107+
# List registered workers (with health/load)
108+
resp = requests.get(f"{ROUTER}/workers")
109+
print(resp.json())
110+
worker_id = resp.json()["workers"][0]["worker_id"]
111+
112+
# Get / update worker details
113+
resp = requests.get(f"{ROUTER}/workers/{worker_id}")
114+
print(resp.json())
115+
resp = requests.put(
116+
f"{ROUTER}/workers/{worker_id}",
117+
json={"is_dead": False, "refresh_video_support": True},
118+
)
102119
print(resp.json())
103120

104121
# Image generation request (returns base64-encoded image)
105-
resp = requests.post(f"{ROUTER}/generate", json={
122+
resp = requests.post(f"{ROUTER}/v1/images/generations", json={
106123
"model": "Qwen/Qwen-Image",
107124
"prompt": "a cute cat",
108125
"num_images": 1,
@@ -117,10 +134,18 @@ with open("output.png", "wb") as f:
117134
f.write(img)
118135
print("Saved to output.png")
119136

137+
# Video generation request
138+
# Note that Qwen-Image does not support video generation,
139+
# so this request will fail.
120140

121-
# Check per-worker health and load
122-
resp = requests.get(f"{ROUTER}/health_workers")
141+
resp = requests.post(f"{ROUTER}/v1/videos", json={
142+
"model": "Qwen/Qwen-Image",
143+
"prompt": "a flowing river",
144+
})
123145
print(resp.json())
146+
video_id = resp.json().get("video_id") or resp.json().get("id")
147+
if video_id:
148+
print(requests.get(f"{ROUTER}/v1/videos/{video_id}").json())
124149

125150
# Update weights from disk
126151
resp = requests.post(f"{ROUTER}/update_weights_from_disk", json={
@@ -135,11 +160,16 @@ print(resp.json())
135160
# Check router health
136161
curl http://localhost:30081/health
137162

138-
# List registered workers
139-
curl http://localhost:30081/list_workers
163+
# Register a worker
164+
curl -X POST http://localhost:30081/workers \
165+
-H "Content-Type: application/json" \
166+
-d '{"url": "http://localhost:30000"}'
167+
168+
# List registered workers (with health/load)
169+
curl http://localhost:30081/workers
140170

141171
# Image generation request (returns base64-encoded image)
142-
curl -X POST http://localhost:30081/generate \
172+
curl -X POST http://localhost:30081/v1/images/generations \
143173
-H "Content-Type: application/json" \
144174
-d '{
145175
"model": "Qwen/Qwen-Image",
@@ -149,7 +179,7 @@ curl -X POST http://localhost:30081/generate \
149179
}'
150180

151181
# Decode and save the image locally
152-
curl -s -X POST http://localhost:30081/generate \
182+
curl -s -X POST http://localhost:30081/v1/images/generations \
153183
-H "Content-Type: application/json" \
154184
-d '{
155185
"model": "Qwen/Qwen-Image",
@@ -165,12 +195,71 @@ with open('output.png', 'wb') as f:
165195
print('Saved to output.png')
166196
"
167197

198+
# Video generation request (Qwen-Image does not support video generation, so this request will fail)
199+
curl -X POST http://localhost:30081/v1/videos \
200+
-H "Content-Type: application/json" \
201+
-d '{"model": "Qwen/Qwen-Image", "prompt": "a flowing river"}'
202+
203+
# Poll a specific video job by video_id
204+
curl http://localhost:30081/v1/videos/{video_id}
205+
168206

169207
curl -X POST http://localhost:30081/update_weights_from_disk \
170208
-H "Content-Type: application/json" \
171209
-d '{"model_path": "Qwen/Qwen-Image-2512"}'
172210
```
173211

212+
## Router API
213+
214+
### Inference Endpoints
215+
216+
| Method | Path | Description |
217+
|---|---|---|
218+
| `POST` | `/v1/images/generations` | Entrypoint for text-to-image generation |
219+
| `POST` | `/v1/videos` | Entrypoint for text-to-video generation |
220+
221+
### Videos Result Query
222+
223+
| Method | Path | Description |
224+
|---|---|---|
225+
| `GET` | `/v1/videos` | List or poll video jobs |
226+
| `GET` | `/v1/videos/{video_id}` | Get status/details of a single video job |
227+
| `GET` | `/v1/videos/{video_id}/content` | Download generated video content |
228+
229+
Video query routing is stable per `video_id`: the router caches the `video_id -> worker` mapping on creation (`POST /v1/videos`), then forwards detail/content queries to the same worker. An unknown `video_id` returns `404`.
230+
231+
### Model Discovery and Health Checks
232+
233+
| Method | Path | Description |
234+
|---|---|---|
235+
| `GET` | `/v1/models` | OpenAI-style model discovery |
236+
| `GET` | `/health` | Basic health probe |
237+
238+
`GET /v1/models` aggregates model lists from healthy workers and de-duplicates by model `id`.
239+
240+
### Worker Management APIs
241+
242+
| Method | Path | Description |
243+
|---|---|---|
244+
| `POST` | `/workers` | Register a worker |
245+
| `GET` | `/workers` | List workers (including health/load) |
246+
| `GET` | `/workers/{worker_id}` | Get worker details |
247+
| `PUT` | `/workers/{worker_id}` | Update worker configuration |
248+
| `DELETE` | `/workers/{worker_id}` | Deregister a worker |
249+
250+
`worker_id` is the URL-encoded worker URL (machine-oriented), and each worker payload also includes `display_id` as a human-readable ID.
251+
252+
`PUT /workers/{worker_id}` currently supports:
253+
- `is_dead` (boolean): quarantine (`true`) or recover (`false`) this worker.
254+
- `refresh_video_support` (boolean): re-probe worker `/v1/models` capability.
255+
256+
### RL Related API
257+
258+
| Method | Path | Description |
259+
|---|---|---|
260+
| `POST` | `/update_weights_from_disk` | Reload weights from disk on all healthy workers |
261+
262+
174263
## Acknowledgment
175264

176265
This project is derived from [radixark/miles#544](https://github.com/radixark/miles/pull/544). Thanks to the original authors.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,4 @@ where = ["src"]
4141

4242
[tool.pytest.ini_options]
4343
testpaths = ["tests/unit"]
44+
pythonpath = ["src"]

src/sglang_diffusion_routing/cli/main.py

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,55 @@
55

66
import argparse
77
import asyncio
8+
import os
89
import sys
9-
import threading
10+
from contextlib import suppress
1011

1112
from sglang_diffusion_routing import DiffusionRouter
1213
from sglang_diffusion_routing.launcher import config as _lcfg
1314

1415

16+
def _print_fired_up_banner(log_prefix: str) -> None:
17+
banner = r"""
18+
19+
____ ____ _ ____ _ __ __ _
20+
/ ___| / ___| | __ _ _ __ __ _| _ \(_)/ _|/ _|_ _ ___(_) ___ _ __
21+
\___ \| | _| | / _` | '_ \ / _` | | | | | |_| |_| | | / __| |/ _ \| '_ \
22+
___) | |_| | |__| (_| | | | | (_| | |_| | | _| _| |_| \__ \ | (_) | | | |
23+
|____/ \____|_____\__,_|_| |_|\__, |____/|_|_| |_| \__,_|___/_|\___/|_| |_|
24+
|___/
25+
26+
____ _ _____ _
27+
| _ \ ___ _ _| |_ ___ ____| ___(_)_ __ ___ __| | _ _ _ __
28+
| |_) / _ \| | | | __/ _ \ __| |_ | | '__/ _ \/ _` | | | | | '_ \
29+
| _ < (_) | |_| | || __/ | | _| | | | | __/ (_| | | |_| | |_) |
30+
|_| \_\___/ \__,_|\__\___|_| |_| |_|_| \___|\__,_| \__,_| .__/
31+
|_|
32+
33+
"""
34+
use_color = sys.stdout.isatty() and "NO_COLOR" not in os.environ
35+
if not use_color:
36+
print(f"{log_prefix} {banner}", flush=True)
37+
return
38+
39+
colors = [
40+
"\033[38;5;45m", # cyan
41+
"\033[38;5;51m", # bright cyan
42+
"\033[38;5;123m", # light cyan
43+
"\033[38;5;159m", # pale blue
44+
]
45+
reset = "\033[0m"
46+
colored_lines = []
47+
for idx, line in enumerate(banner.splitlines()):
48+
if line.strip():
49+
color = colors[idx % len(colors)]
50+
colored_lines.append(f"{color}{line}{reset}")
51+
else:
52+
colored_lines.append(line)
53+
colored_banner = "\n".join(colored_lines)
54+
print(f"{log_prefix} {colored_banner}", flush=True)
55+
56+
1557
def _run_router_server(
1658
args: argparse.Namespace,
1759
router: DiffusionRouter,
@@ -40,15 +82,37 @@ async def _refresh_all_worker_video_support() -> None:
4082

4183
print(f"{log_prefix} starting router on {args.host}:{args.port}", flush=True)
4284
print(
43-
f"{log_prefix} workers: {list(router.worker_request_counts.keys()) or '(none - add via POST /add_worker)'}",
85+
f"{log_prefix} workers: {list(router.worker_request_counts.keys()) or '(none - add via POST /workers)'}",
4486
flush=True,
4587
)
46-
uvicorn.run(
47-
router.app,
88+
config = uvicorn.Config(
89+
app=router.app,
4890
host=args.host,
4991
port=args.port,
5092
log_level=getattr(args, "log_level", "info"),
5193
)
94+
server = uvicorn.Server(config)
95+
96+
async def _serve_with_banner() -> None:
97+
banner_printed = False
98+
99+
async def _wait_and_print_banner() -> None:
100+
nonlocal banner_printed
101+
while not server.started and not server.should_exit:
102+
await asyncio.sleep(0.1)
103+
if server.started and not banner_printed:
104+
banner_printed = True
105+
_print_fired_up_banner(log_prefix)
106+
107+
watcher = asyncio.create_task(_wait_and_print_banner())
108+
try:
109+
await server.serve()
110+
finally:
111+
watcher.cancel()
112+
with suppress(asyncio.CancelledError):
113+
await watcher
114+
115+
asyncio.run(_serve_with_banner())
52116

53117

54118
def _add_router_args(parser: argparse.ArgumentParser) -> None:
@@ -114,16 +178,19 @@ def _handle_router(args: argparse.Namespace) -> int:
114178
launcher_cfg = _lcfg.load_launcher_config(args.launcher_config)
115179
wait_timeout = launcher_cfg.wait_timeout
116180
backend = _lcfg.create_backend(launcher_cfg)
117-
backend.launch()
118-
threading.Thread(
119-
target=backend.wait_ready_and_register,
120-
kwargs=dict(
121-
register_func=router.register_worker,
122-
timeout=wait_timeout,
123-
log_prefix=log_prefix,
124-
),
125-
daemon=True,
126-
).start()
181+
launched_urls = backend.launch()
182+
backend.wait_ready_and_register(
183+
register_func=router.register_worker,
184+
timeout=wait_timeout,
185+
log_prefix=log_prefix,
186+
)
187+
registered_urls = set(router.worker_request_counts.keys())
188+
pending_urls = [u for u in launched_urls if u not in registered_urls]
189+
if pending_urls:
190+
raise RuntimeError(
191+
"managed workers failed to become healthy before router startup: "
192+
+ ", ".join(pending_urls)
193+
)
127194

128195
_run_router_server(args, router=router, log_prefix=log_prefix)
129196
return 0

0 commit comments

Comments
 (0)