Skip to content

Commit 9b46df9

Browse files
Merge pull request zhaochenyang20#14 from zhaochenyang20/validate_video_generation
[router] Validate video generation before selection; Support selective worker choice
2 parents eef1d43 + 5fab8c0 commit 9b46df9

8 files changed

Lines changed: 128 additions & 154 deletions

File tree

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "sglang"]
2+
path = sglang
3+
url = https://github.com/sgl-project/sglang.git

README.md

Lines changed: 10 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,19 @@ From repository root:
2020
# python3 -m venv .venv
2121
# source .venv/bin/activate
2222
# pip install uv
23-
git clone https://github.com/sglang/sglang-diffusion-routing.git
23+
git clone --recursive https://github.com/sglang/sglang-diffusion-routing.git
2424
cd sglang-diffusion-routing
2525
uv pip install .
2626
```
2727

2828
Workers require SGLang diffusion support:
2929

3030
```bash
31+
# If you cloned sglang-diffusion-routing without --recursive, run:
32+
# git submodule update --init --recursive
33+
cd sglang
3134
uv pip install "sglang[diffusion]" --prerelease=allow
35+
cd ..
3236
```
3337

3438
## Quick Start
@@ -114,16 +118,10 @@ with open('output.png', 'wb') as f:
114118
print('Saved to output.png')
115119
"
116120

117-
# Video generation request
118-
curl -X POST http://localhost:30081/generate_video \
119-
-H "Content-Type: application/json" \
120-
-d '{
121-
"model": "Qwen/Qwen-Image",
122-
"prompt": "a flowing river"
123-
}'
124121

125-
# Check per-worker health and load
126-
curl http://localhost:30081/health_workers
122+
curl -X POST http://localhost:30081/update_weights_from_disk \
123+
-H "Content-Type: application/json" \
124+
-d '{"model_path": "Qwen/Qwen-Image-2512"}'
127125
```
128126

129127
### Python requests examples
@@ -177,66 +175,10 @@ print(resp.json())
177175
- `GET /health`: aggregated router health.
178176
- `GET /health_workers`: per-worker health and active request counts.
179177
- `POST /generate`: forwards to worker `/v1/images/generations`.
180-
- `POST /generate_video`: forwards to worker `/v1/videos`.
178+
- `POST /generate_video`: forwards to a video-capable worker's `/v1/videos`; image-only workers (`T2I`/`I2I`/`TI2I`) are excluded, and `400` is returned when no video-capable worker is available.
181179
- `POST /update_weights_from_disk`: broadcast to healthy workers.
182180
- `GET|POST|PUT|DELETE /{path}`: catch-all proxy forwarding.
183-
184-
## `update_weights_from_disk` behavior
185-
186-
Full details: [docs/update_weights_from_disk.md](docs/update_weights_from_disk.md)
187-
188-
- The router forwards request payloads as-is to each healthy worker.
189-
- The router does not validate payload schema; payload semantics are worker-defined.
190-
- Worker servers must implement `POST /update_weights_from_disk`.
191-
192-
Example:
193-
194-
```bash
195-
curl -X POST http://localhost:30081/update_weights_from_disk \
196-
-H "Content-Type: application/json" \
197-
-d '{"model_path": "/path/to/new/weights"}'
198-
```
199-
200-
Response shape:
201-
202-
```json
203-
{
204-
"results": [
205-
{
206-
"worker_url": "http://localhost:30000",
207-
"status_code": 200,
208-
"body": {
209-
"ok": true
210-
}
211-
}
212-
]
213-
}
214-
```
215-
216-
## Benchmark Scripts
217-
218-
Benchmark scripts are available under `tests/benchmarks/diffusion_router/` and are intended for manual runs.
219-
They are not part of default unit test collection (`pytest tests/unit -v`).
220-
221-
Single benchmark:
222-
223-
```bash
224-
SGLANG_USE_MODELSCOPE=TRUE python tests/benchmarks/diffusion_router/bench_router.py \
225-
--model Qwen/Qwen-Image \
226-
--num-workers 2 \
227-
--num-prompts 20 \
228-
--max-concurrency 4
229-
```
230-
231-
Algorithm comparison:
232-
233-
```bash
234-
SGLANG_USE_MODELSCOPE=TRUE python tests/benchmarks/diffusion_router/bench_routing_algorithms.py \
235-
--model Qwen/Qwen-Image \
236-
--num-workers 2 \
237-
--num-prompts 20 \
238-
--max-concurrency 4
239-
```
181+
- `POST /update_weights_from_disk`: broadcast to all healthy workers.
240182

241183
## Project Layout
242184

development.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
Development install:
1+
# Development
2+
3+
## Development Install
24

35
```bash
46
pip install -e .
@@ -10,3 +12,28 @@ Run tests:
1012
pip install pytest
1113
pytest tests/unit -v
1214
```
15+
16+
## Benchmark Scripts
17+
18+
Benchmark scripts are available under `tests/benchmarks/diffusion_router/` and are intended for manual runs.
19+
They are not part of default unit test collection (`pytest tests/unit -v`).
20+
21+
Single benchmark:
22+
23+
```bash
24+
python tests/benchmarks/diffusion_router/bench_router.py \
25+
--model Qwen/Qwen-Image \
26+
--num-workers 2 \
27+
--num-prompts 20 \
28+
--max-concurrency 4
29+
```
30+
31+
Algorithm comparison:
32+
33+
```bash
34+
python tests/benchmarks/diffusion_router/bench_routing_algorithms.py \
35+
--model Qwen/Qwen-Image \
36+
--num-workers 2 \
37+
--num-prompts 20 \
38+
--max-concurrency 4
39+
```

docs/update_weights_from_disk.md

Lines changed: 0 additions & 65 deletions
This file was deleted.

sglang

Submodule sglang added at 45095ba

src/sglang_diffusion_routing/cli/main.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import argparse
7+
import asyncio
78
import sys
89

910
from sglang_diffusion_routing import DiffusionRouter
@@ -25,8 +26,18 @@ def _run_router_server(
2526
worker_urls if worker_urls is not None else args.worker_urls or []
2627
)
2728
router = DiffusionRouter(args, verbose=args.verbose)
29+
refresh_tasks = []
2830
for url in worker_urls:
29-
router.register_worker(url)
31+
normalized_url = router.normalize_worker_url(url)
32+
router.register_worker(normalized_url)
33+
refresh_tasks.append(router.refresh_worker_video_support(normalized_url))
34+
35+
if refresh_tasks:
36+
37+
async def _refresh_all_worker_video_support() -> None:
38+
await asyncio.gather(*refresh_tasks)
39+
40+
asyncio.run(_refresh_all_worker_video_support())
3041

3142
print(f"{log_prefix} starting router on {args.host}:{args.port}", flush=True)
3243
print(

src/sglang_diffusion_routing/router/diffusion_router.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
logger = logging.getLogger(__name__)
1717

1818
_METADATA_HOSTS = {"169.254.169.254", "metadata.google.internal"}
19+
_IMAGE_TASK_TYPES = {"T2I", "I2I", "TI2I"}
1920

2021

2122
class DiffusionRouter:
23+
2224
def __init__(self, args, verbose: bool = False):
2325
"""Initialize the router for load-balancing sglang-diffusion workers."""
2426
self.args = args
@@ -32,6 +34,9 @@ def __init__(self, args, verbose: bool = False):
3234
self.worker_request_counts: dict[str, int] = {}
3335
# URL -> consecutive health check failures
3436
self.worker_failure_counts: dict[str, int] = {}
37+
# URL -> whether worker supports video generation
38+
# True: supports, False: image-only, None: unknown/unprobed
39+
self.worker_video_support: dict[str, bool | None] = {}
3540
# quarantined workers excluded from routing
3641
self.dead_workers: set[str] = set()
3742
self._health_task: asyncio.Task | None = None
@@ -139,14 +144,23 @@ async def _health_check_loop(self) -> None:
139144
)
140145
await asyncio.sleep(5)
141146

142-
def _use_url(self) -> str:
143-
"""Select a worker URL based on the configured routing algorithm."""
147+
def _select_worker_by_routing(self, worker_urls: list[str] | None = None) -> str:
148+
"""Select a worker URL based on routing algorithm and optional candidates.
149+
150+
Args:
151+
worker_urls: Optional list of worker URLs to consider. If provided,
152+
only these workers will be considered for selection. If not provided,
153+
all registered workers will be considered.
154+
"""
144155
if not self.worker_request_counts:
145156
raise RuntimeError("No workers registered in the pool")
146157

147158
valid_workers = [
148159
w for w in self.worker_request_counts if w not in self.dead_workers
149160
]
161+
if worker_urls is not None:
162+
allowed = {w for w in worker_urls if w in self.worker_request_counts}
163+
valid_workers = [w for w in valid_workers if w in allowed]
150164
if not valid_workers:
151165
raise RuntimeError("No healthy workers available in the pool")
152166

@@ -202,13 +216,14 @@ def _build_proxy_response(
202216
media_type=content_type,
203217
)
204218

205-
async def _forward_to_worker(self, request: Request, path: str) -> Response:
206-
"""Forward a request to a selected worker and return the response."""
219+
async def _forward_to_worker(
220+
self, request: Request, path: str, worker_urls: list[str] | None = None
221+
) -> Response:
222+
"""Forward request to a selected worker (optionally from candidate URLs)."""
207223
try:
208-
worker_url = self._use_url()
224+
worker_url = self._select_worker_by_routing(worker_urls=worker_urls)
209225
except RuntimeError as exc:
210226
return JSONResponse(status_code=503, content={"error": str(exc)})
211-
212227
try:
213228
query = request.url.query
214229
url = (
@@ -243,6 +258,29 @@ async def _forward_to_worker(self, request: Request, path: str) -> Response:
243258
finally:
244259
self._finish_url(worker_url)
245260

261+
async def _probe_worker_video_support(self, worker_url: str) -> bool | None:
262+
"""Probe /v1/models and infer if this worker supports video generation."""
263+
try:
264+
response = await self.client.get(f"{worker_url}/v1/models", timeout=5.0)
265+
if response.status_code == 200:
266+
payload = response.json()
267+
data = payload.get("data")
268+
task_type = (
269+
data[0].get("task_type")
270+
if isinstance(data, list) and data
271+
else None
272+
)
273+
if isinstance(task_type, str):
274+
return task_type.upper() not in _IMAGE_TASK_TYPES
275+
except (httpx.RequestError, json.JSONDecodeError):
276+
return None
277+
278+
async def refresh_worker_video_support(self, worker_url: str) -> None:
279+
"""Refresh cached video capability for a single worker."""
280+
self.worker_video_support[worker_url] = await self._probe_worker_video_support(
281+
worker_url
282+
)
283+
246284
async def _broadcast_to_workers(
247285
self, path: str, body: bytes, headers: dict
248286
) -> list[dict]:
@@ -297,7 +335,7 @@ def _sanitize_response_headers(headers) -> dict:
297335
}
298336

299337
@staticmethod
300-
def _normalize_worker_url(url: str) -> str:
338+
def normalize_worker_url(url: str) -> str:
301339
if not isinstance(url, str):
302340
raise ValueError("worker_url must be a string")
303341

@@ -345,7 +383,22 @@ async def generate(self, request: Request):
345383

346384
async def generate_video(self, request: Request):
347385
"""Route video generation to /v1/videos."""
348-
return await self._forward_to_worker(request, "v1/videos")
386+
candidate_workers = [
387+
worker_url
388+
for worker_url, support in self.worker_video_support.items()
389+
if support
390+
]
391+
392+
if not candidate_workers:
393+
return JSONResponse(
394+
status_code=400,
395+
content={
396+
"error": "No video-capable workers available in current worker pool.",
397+
},
398+
)
399+
return await self._forward_to_worker(
400+
request, "v1/videos", worker_urls=candidate_workers
401+
)
349402

350403
async def health(self, request: Request):
351404
"""Aggregated health status: healthy if at least one worker is alive."""
@@ -388,10 +441,11 @@ async def update_weights_from_disk(self, request: Request):
388441

389442
def register_worker(self, url: str) -> None:
390443
"""Register a worker URL if not already known."""
391-
normalized_url = self._normalize_worker_url(url)
444+
normalized_url = self.normalize_worker_url(url)
392445
if normalized_url not in self.worker_request_counts:
393446
self.worker_request_counts[normalized_url] = 0
394447
self.worker_failure_counts[normalized_url] = 0
448+
self.worker_video_support[normalized_url] = None
395449
if self.verbose:
396450
print(f"[diffusion-router] Added new worker: {normalized_url}")
397451

@@ -422,6 +476,7 @@ async def add_worker(self, request: Request):
422476
self.register_worker(worker_url)
423477
except ValueError as exc:
424478
return JSONResponse(status_code=400, content={"error": str(exc)})
479+
await self.refresh_worker_video_support(worker_url)
425480
return {
426481
"status": "success",
427482
"worker_urls": list(self.worker_request_counts.keys()),

0 commit comments

Comments
 (0)