Skip to content
14 changes: 3 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,6 @@ with open('output.png', 'wb') as f:
print('Saved to output.png')
"

# Video generation request
curl -X POST http://localhost:30081/generate_video \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen-Image",
"prompt": "a flowing river"
}'

# Check per-worker health and load
curl http://localhost:30081/health_workers
```
Expand All @@ -132,7 +124,7 @@ curl http://localhost:30081/health_workers
- `GET /health`: aggregated router health.
- `GET /health_workers`: per-worker health and active request counts.
- `POST /generate`: forwards to worker `/v1/images/generations`.
- `POST /generate_video`: forwards to worker `/v1/videos`.
- `POST /generate_video`: forwards to worker `/v1/videos`; rejects image-only workers (`T2I`/`I2I`/`TI2I`) with `400`.
- `POST /update_weights_from_disk`: broadcast to healthy workers.
- `GET|POST|PUT|DELETE /{path}`: catch-all proxy forwarding.

Expand Down Expand Up @@ -176,7 +168,7 @@ They are not part of default unit test collection (`pytest tests/unit -v`).
Single benchmark:

```bash
SGLANG_USE_MODELSCOPE=TRUE python tests/benchmarks/diffusion_router/bench_router.py \
python tests/benchmarks/diffusion_router/bench_router.py \
--model Qwen/Qwen-Image \
--num-workers 2 \
--num-prompts 20 \
Expand All @@ -186,7 +178,7 @@ SGLANG_USE_MODELSCOPE=TRUE python tests/benchmarks/diffusion_router/bench_router
Algorithm comparison:

```bash
SGLANG_USE_MODELSCOPE=TRUE python tests/benchmarks/diffusion_router/bench_routing_algorithms.py \
python tests/benchmarks/diffusion_router/bench_routing_algorithms.py \
--model Qwen/Qwen-Image \
--num-workers 2 \
--num-prompts 20 \
Expand Down
1 change: 1 addition & 0 deletions sglang
Submodule sglang added at 45095b
2 changes: 2 additions & 0 deletions src/sglang_diffusion_routing/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import argparse
import asyncio
import sys

from sglang_diffusion_routing import DiffusionRouter
Expand All @@ -27,6 +28,7 @@ def _run_router_server(
router = DiffusionRouter(args, verbose=args.verbose)
for url in worker_urls:
router.register_worker(url)
asyncio.run(router._refresh_worker_video_support(url))
Comment thread
zhaochenyang20 marked this conversation as resolved.
Outdated

print(f"{log_prefix} starting router on {args.host}:{args.port}", flush=True)
print(
Expand Down
72 changes: 64 additions & 8 deletions src/sglang_diffusion_routing/router/diffusion_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# NOTE(review): these look like cloud instance-metadata endpoints; presumably
# they are rejected as worker URLs — the usage is outside this view, confirm.
_METADATA_HOSTS = {"169.254.169.254", "metadata.google.internal"}
# Task types that mark a worker as image-only. The video-support probe
# upper-cases the task_type a worker reports and checks it against this set.
_IMAGE_TASK_TYPES = {"T2I", "I2I", "TI2I"}


class DiffusionRouter:

def __init__(self, args, verbose: bool = False):
"""Initialize the router for load-balancing sglang-diffusion workers."""
self.args = args
Expand All @@ -32,6 +34,9 @@ def __init__(self, args, verbose: bool = False):
self.worker_request_counts: dict[str, int] = {}
# URL -> consecutive health check failures
self.worker_failure_counts: dict[str, int] = {}
# URL -> whether worker supports video generation
# True: supports, False: image-only, None: unknown/unprobed
self.worker_video_support: dict[str, bool | None] = {}
# quarantined workers excluded from routing
self.dead_workers: set[str] = set()
self._health_task: asyncio.Task | None = None
Expand Down Expand Up @@ -139,14 +144,23 @@ async def _health_check_loop(self) -> None:
)
await asyncio.sleep(5)

def _use_url(self) -> str:
Comment thread
zhaochenyang20 marked this conversation as resolved.
"""Select a worker URL based on the configured routing algorithm."""
def _select_worker_by_routing(self, worker_urls: list[str] | None = None) -> str:
"""Select a worker URL based on routing algorithm and optional candidates.

Args:
worker_urls: Optional list of worker URLs to consider. If provided,
only these workers will be considered for selection. If not provided,
all registered workers will be considered.
"""
if not self.worker_request_counts:
raise RuntimeError("No workers registered in the pool")

valid_workers = [
w for w in self.worker_request_counts if w not in self.dead_workers
]
if worker_urls is not None:
allowed = {w for w in worker_urls if w in self.worker_request_counts}
valid_workers = [w for w in valid_workers if w in allowed]
if not valid_workers:
raise RuntimeError("No healthy workers available in the pool")

Expand Down Expand Up @@ -202,13 +216,14 @@ def _build_proxy_response(
media_type=content_type,
)

async def _forward_to_worker(self, request: Request, path: str) -> Response:
"""Forward a request to a selected worker and return the response."""
async def _forward_to_worker(
self, request: Request, path: str, worker_urls: list[str] | None = None
) -> Response:
"""Forward request to a selected worker (optionally from candidate URLs)."""
try:
worker_url = self._use_url()
worker_url = self._select_worker_by_routing(worker_urls=worker_urls)
except RuntimeError as exc:
return JSONResponse(status_code=503, content={"error": str(exc)})

try:
query = request.url.query
url = (
Expand Down Expand Up @@ -243,6 +258,29 @@ async def _forward_to_worker(self, request: Request, path: str) -> Response:
finally:
self._finish_url(worker_url)

async def _probe_worker_video_support(self, worker_url: str) -> bool | None:
    """Probe a worker's ``/v1/models`` and infer whether it can generate video.

    Args:
        worker_url: Base URL of the worker to probe.

    Returns:
        ``True`` if the first model entry reports a task type that is not
        image-only, ``False`` if it reports an image-only task type (one of
        ``_IMAGE_TASK_TYPES``, compared case-insensitively), and ``None`` when
        the capability cannot be determined (request failed, non-200 status,
        unparsable body, or no string ``task_type`` in the payload).
    """
    # Keep the try body limited to the network/JSON step so that programming
    # errors in the interpretation logic below are not silently swallowed.
    try:
        response = await self.client.get(f"{worker_url}/v1/models", timeout=5.0)
        if response.status_code != 200:
            return None
        payload = response.json()
    except Exception:
        # Best-effort probe: an unreachable or misbehaving worker is reported
        # as "unknown" rather than failing the caller.
        logger.debug(
            "Video-support probe failed for %s", worker_url, exc_info=True
        )
        return None

    data = payload.get("data") if isinstance(payload, dict) else None
    first_entry = data[0] if isinstance(data, list) and data else None
    task_type = (
        first_entry.get("task_type") if isinstance(first_entry, dict) else None
    )
    if not isinstance(task_type, str):
        return None
    # _IMAGE_TASK_TYPES is a module-level constant, not an attribute of the
    # router instance; looking it up on ``self`` raised AttributeError, which
    # the blanket ``except`` turned into a permanent "unknown" result.
    return task_type.upper() not in _IMAGE_TASK_TYPES
Comment thread
zhaochenyang20 marked this conversation as resolved.
Outdated

async def _refresh_worker_video_support(self, worker_url: str) -> None:
    """Re-probe one worker and store the result in ``worker_video_support``.

    The stored value is whatever ``_probe_worker_video_support`` returns for
    ``worker_url`` and is keyed by ``worker_url`` exactly as passed in.
    """
    supports_video = await self._probe_worker_video_support(worker_url)
    self.worker_video_support[worker_url] = supports_video

async def _broadcast_to_workers(
self, path: str, body: bytes, headers: dict
) -> list[dict]:
Expand Down Expand Up @@ -345,7 +383,22 @@ async def generate(self, request: Request):

async def generate_video(self, request: Request):
    """Route a video generation request to a video-capable worker's /v1/videos.

    Only workers whose cached capability in ``worker_video_support`` is
    ``True`` are offered to the routing step. If no such worker is known,
    workers whose capability is still unknown (``None``) and that are not
    quarantined are re-probed once before giving up, so a probe that failed
    at registration time does not exclude a worker permanently.

    Returns:
        The proxied worker response, or a 400 JSON error when the pool has no
        video-capable worker.
    """

    def _known_video_workers() -> list[str]:
        return [
            worker_url
            for worker_url, support in self.worker_video_support.items()
            if support is True
        ]

    candidate_workers = _known_video_workers()

    if not candidate_workers:
        # Re-probe only on the path that would otherwise reject the request,
        # so requests served by known-capable workers pay no extra latency.
        unprobed_workers = [
            worker_url
            for worker_url, support in self.worker_video_support.items()
            if support is None and worker_url not in self.dead_workers
        ]
        if unprobed_workers:
            await asyncio.gather(
                *(
                    self._refresh_worker_video_support(worker_url)
                    for worker_url in unprobed_workers
                )
            )
            candidate_workers = _known_video_workers()

    if not candidate_workers:
        return JSONResponse(
            status_code=400,
            content={
                "error": "No video-capable workers available in current worker pool.",
            },
        )
    return await self._forward_to_worker(
        request, "v1/videos", worker_urls=candidate_workers
    )

async def health(self, request: Request):
"""Aggregated health status: healthy if at least one worker is alive."""
Expand Down Expand Up @@ -392,6 +445,7 @@ def register_worker(self, url: str) -> None:
if normalized_url not in self.worker_request_counts:
self.worker_request_counts[normalized_url] = 0
self.worker_failure_counts[normalized_url] = 0
self.worker_video_support[normalized_url] = None
if self.verbose:
print(f"[diffusion-router] Added new worker: {normalized_url}")

Expand Down Expand Up @@ -419,9 +473,11 @@ async def add_worker(self, request: Request):
)

try:
self.register_worker(worker_url)
normalized_url = self._normalize_worker_url(worker_url)
self.register_worker(normalized_url)
Comment thread
zhaochenyang20 marked this conversation as resolved.
Outdated
except ValueError as exc:
return JSONResponse(status_code=400, content={"error": str(exc)})
await self._refresh_worker_video_support(normalized_url)
Comment thread
zhaochenyang20 marked this conversation as resolved.
Outdated
return {
"status": "success",
"worker_urls": list(self.worker_request_counts.keys()),
Expand Down
Loading