diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..377caea --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "sglang"] + path = sglang + url = https://github.com/sgl-project/sglang.git diff --git a/README.md b/README.md index 5bbe48a..12025a3 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ From repository root: # python3 -m venv .venv # source .venv/bin/activate # pip install uv -git clone https://github.com/sglang/sglang-diffusion-routing.git +git clone --recursive https://github.com/sglang/sglang-diffusion-routing.git cd sglang-diffusion-routing uv pip install . ``` @@ -28,7 +28,11 @@ uv pip install . Workers require SGLang diffusion support: ```bash +# If cloned sglang-diffusion-routing without --recursive, run: +# git submodule update --init --recursive +cd sglang uv pip install "sglang[diffusion]" --prerelease=allow +cd .. ``` ## Quick Start @@ -113,16 +117,10 @@ with open('output.png', 'wb') as f: print('Saved to output.png') " -# Video generation request -curl -X POST http://localhost:30081/generate_video \ - -H "Content-Type: application/json" \ - -d '{ - "model": "Qwen/Qwen-Image", - "prompt": "a flowing river" - }' -# Check per-worker health and load -curl http://localhost:30081/health_workers +curl -X POST http://localhost:30081/update_weights_from_disk \ + -H "Content-Type: application/json" \ + -d '{"model_path": "Qwen/Qwen-Image-2512"}' ``` ## Router API @@ -132,66 +130,10 @@ curl http://localhost:30081/health_workers - `GET /health`: aggregated router health. - `GET /health_workers`: per-worker health and active request counts. - `POST /generate`: forwards to worker `/v1/images/generations`. -- `POST /generate_video`: forwards to worker `/v1/videos`. +- `POST /generate_video`: forwards to worker `/v1/videos`; rejects image-only workers (`T2I`/`I2I`/`TI2I`) with `400`. - `POST /update_weights_from_disk`: broadcast to healthy workers. - `GET|POST|PUT|DELETE /{path}`: catch-all proxy forwarding. - -## `update_weights_from_disk` behavior - -Full details: [docs/update_weights_from_disk.md](docs/update_weights_from_disk.md) - -- The router forwards request payloads as-is to each healthy worker. -- The router does not validate payload schema; payload semantics are worker-defined. -- Worker servers must implement `POST /update_weights_from_disk`. - -Example: - -```bash -curl -X POST http://localhost:30081/update_weights_from_disk \ - -H "Content-Type: application/json" \ - -d '{"model_path": "/path/to/new/weights"}' -``` - -Response shape: - -```json -{ - "results": [ - { - "worker_url": "http://localhost:30000", - "status_code": 200, - "body": { - "ok": true - } - } - ] -} -``` - -## Benchmark Scripts - -Benchmark scripts are available under `tests/benchmarks/diffusion_router/` and are intended for manual runs. -They are not part of default unit test collection (`pytest tests/unit -v`). - -Single benchmark: - -```bash -SGLANG_USE_MODELSCOPE=TRUE python tests/benchmarks/diffusion_router/bench_router.py \ - --model Qwen/Qwen-Image \ - --num-workers 2 \ - --num-prompts 20 \ - --max-concurrency 4 -``` - -Algorithm comparison: - -```bash -SGLANG_USE_MODELSCOPE=TRUE python tests/benchmarks/diffusion_router/bench_routing_algorithms.py \ - --model Qwen/Qwen-Image \ - --num-workers 2 \ - --num-prompts 20 \ - --max-concurrency 4 -``` +- `POST /update_weights_from_disk`: broadcast to all healthy workers. ## Project Layout diff --git a/development.md b/development.md index 7c69da8..bc1ca23 100644 --- a/development.md +++ b/development.md @@ -1,4 +1,6 @@ -Development install: +# Development + +## Development Install ```bash pip install -e . @@ -10,3 +12,28 @@ Run tests: pip install pytest pytest tests/unit -v ``` + +## Benchmark Scripts + +Benchmark scripts are available under `tests/benchmarks/diffusion_router/` and are intended for manual runs. +They are not part of default unit test collection (`pytest tests/unit -v`). + +Single benchmark: + +```bash +python tests/benchmarks/diffusion_router/bench_router.py \ + --model Qwen/Qwen-Image \ + --num-workers 2 \ + --num-prompts 20 \ + --max-concurrency 4 +``` + +Algorithm comparison: + +```bash +python tests/benchmarks/diffusion_router/bench_routing_algorithms.py \ + --model Qwen/Qwen-Image \ + --num-workers 2 \ + --num-prompts 20 \ + --max-concurrency 4 +``` diff --git a/docs/update_weights_from_disk.md b/docs/update_weights_from_disk.md deleted file mode 100644 index c6362b5..0000000 --- a/docs/update_weights_from_disk.md +++ /dev/null @@ -1,65 +0,0 @@ -# update_weights_from_disk - -This document describes `POST /update_weights_from_disk` behavior in this repository. - -## Router behavior - -The router does not validate or transform payload fields. -It forwards the original request body to every healthy worker and returns per-worker results. - -Payload semantics are therefore defined by the worker implementation, not by the router. - -## Requirements - -- Worker servers must implement `POST /update_weights_from_disk`. -- For SGLang workers, use a version that includes this endpoint. -- Weights must match your worker runtime expectations. - -## Basic example - -```bash -curl -X POST http://localhost:30080/update_weights_from_disk \ - -H "Content-Type: application/json" \ - -d '{"model_path": "/path/to/new/weights"}' -``` - -## Optional fields - -Some worker versions support optional fields such as `target_modules`: - -```bash -curl -X POST http://localhost:30080/update_weights_from_disk \ - -H "Content-Type: application/json" \ - -d '{"model_path": "/path/to/weights", "target_modules": ["transformer", "vae"]}' -``` - -If your worker version does not support extra fields, failure is returned by the worker side. - -## Response shape - -The router response includes one item per healthy worker: - -```json -{ - "results": [ - { - "worker_url": "http://localhost:10090", - "status_code": 200, - "body": { - "ok": true - } - }, - { - "worker_url": "http://localhost:10092", - "status_code": 500, - "body": { - "error": "worker-side failure" - } - } - ] -} -``` - -Notes: -- Quarantined workers are excluded from broadcast. -- Transport/runtime exceptions are surfaced as per-worker `status_code=502`. diff --git a/sglang b/sglang new file mode 160000 index 0000000..45095ba --- /dev/null +++ b/sglang @@ -0,0 +1 @@ +Subproject commit 45095bac70ef1382425cb86f4b7af66dc6e7641c diff --git a/src/sglang_diffusion_routing/cli/main.py b/src/sglang_diffusion_routing/cli/main.py index 3910826..5d1336a 100644 --- a/src/sglang_diffusion_routing/cli/main.py +++ b/src/sglang_diffusion_routing/cli/main.py @@ -4,6 +4,7 @@ from __future__ import annotations import argparse +import asyncio import sys from sglang_diffusion_routing import DiffusionRouter @@ -25,8 +26,18 @@ def _run_router_server( worker_urls if worker_urls is not None else args.worker_urls or [] ) router = DiffusionRouter(args, verbose=args.verbose) + refresh_tasks = [] for url in worker_urls: - router.register_worker(url) + normalized_url = router.normalize_worker_url(url) + router.register_worker(normalized_url) + refresh_tasks.append(router.refresh_worker_video_support(normalized_url)) + + if refresh_tasks: + + async def _refresh_all_worker_video_support() -> None: + await asyncio.gather(*refresh_tasks) + + asyncio.run(_refresh_all_worker_video_support()) print(f"{log_prefix} starting router on {args.host}:{args.port}", flush=True) print( diff --git a/src/sglang_diffusion_routing/router/diffusion_router.py b/src/sglang_diffusion_routing/router/diffusion_router.py index c2b9c87..bd0307b 100644 --- a/src/sglang_diffusion_routing/router/diffusion_router.py +++ b/src/sglang_diffusion_routing/router/diffusion_router.py @@ -16,9 +16,11 @@ logger = logging.getLogger(__name__) _METADATA_HOSTS = {"169.254.169.254", "metadata.google.internal"} +_IMAGE_TASK_TYPES = {"T2I", "I2I", "TI2I"} class DiffusionRouter: + def __init__(self, args, verbose: bool = False): """Initialize the router for load-balancing sglang-diffusion workers.""" self.args = args @@ -32,6 +34,9 @@ def __init__(self, args, verbose: bool = False): self.worker_request_counts: dict[str, int] = {} # URL -> consecutive health check failures self.worker_failure_counts: dict[str, int] = {} + # URL -> whether worker supports video generation + # True: supports, False: image-only, None: unknown/unprobed + self.worker_video_support: dict[str, bool | None] = {} # quarantined workers excluded from routing self.dead_workers: set[str] = set() self._health_task: asyncio.Task | None = None @@ -139,14 +144,23 @@ async def _health_check_loop(self) -> None: ) await asyncio.sleep(5) - def _use_url(self) -> str: - """Select a worker URL based on the configured routing algorithm.""" + def _select_worker_by_routing(self, worker_urls: list[str] | None = None) -> str: + """Select a worker URL based on routing algorithm and optional candidates. + + Args: + worker_urls: Optional list of worker URLs to consider. If provided, + only these workers will be considered for selection. If not provided, + all registered workers will be considered. + """ if not self.worker_request_counts: raise RuntimeError("No workers registered in the pool") valid_workers = [ w for w in self.worker_request_counts if w not in self.dead_workers ] + if worker_urls is not None: + allowed = {w for w in worker_urls if w in self.worker_request_counts} + valid_workers = [w for w in valid_workers if w in allowed] if not valid_workers: raise RuntimeError("No healthy workers available in the pool") @@ -202,13 +216,14 @@ def _build_proxy_response( media_type=content_type, ) - async def _forward_to_worker(self, request: Request, path: str) -> Response: - """Forward a request to a selected worker and return the response.""" + async def _forward_to_worker( + self, request: Request, path: str, worker_urls: list[str] | None = None + ) -> Response: + """Forward request to a selected worker (optionally from candidate URLs).""" try: - worker_url = self._use_url() + worker_url = self._select_worker_by_routing(worker_urls=worker_urls) except RuntimeError as exc: return JSONResponse(status_code=503, content={"error": str(exc)}) - try: query = request.url.query url = ( @@ -243,6 +258,29 @@ async def _forward_to_worker(self, request: Request, path: str) -> Response: finally: self._finish_url(worker_url) + async def _probe_worker_video_support(self, worker_url: str) -> bool | None: + """Probe /v1/models and infer if this worker supports video generation.""" + try: + response = await self.client.get(f"{worker_url}/v1/models", timeout=5.0) + if response.status_code == 200: + payload = response.json() + data = payload.get("data") + task_type = ( + data[0].get("task_type") + if isinstance(data, list) and data + else None + ) + if isinstance(task_type, str): + return task_type.upper() not in _IMAGE_TASK_TYPES + except (httpx.RequestError, json.JSONDecodeError): + return None + + async def refresh_worker_video_support(self, worker_url: str) -> None: + """Refresh cached video capability for a single worker.""" + self.worker_video_support[worker_url] = await self._probe_worker_video_support( + worker_url + ) + async def _broadcast_to_workers( self, path: str, body: bytes, headers: dict ) -> list[dict]: @@ -297,7 +335,7 @@ def _sanitize_response_headers(headers) -> dict: } @staticmethod - def _normalize_worker_url(url: str) -> str: + def normalize_worker_url(url: str) -> str: if not isinstance(url, str): raise ValueError("worker_url must be a string") @@ -345,7 +383,22 @@ async def generate(self, request: Request): async def generate_video(self, request: Request): """Route video generation to /v1/videos.""" - return await self._forward_to_worker(request, "v1/videos") + candidate_workers = [ + worker_url + for worker_url, support in self.worker_video_support.items() + if support + ] + + if not candidate_workers: + return JSONResponse( + status_code=400, + content={ + "error": "No video-capable workers available in current worker pool.", + }, + ) + return await self._forward_to_worker( + request, "v1/videos", worker_urls=candidate_workers + ) async def health(self, request: Request): """Aggregated health status: healthy if at least one worker is alive.""" @@ -388,10 +441,11 @@ async def update_weights_from_disk(self, request: Request): def register_worker(self, url: str) -> None: """Register a worker URL if not already known.""" - normalized_url = self._normalize_worker_url(url) + normalized_url = self.normalize_worker_url(url) if normalized_url not in self.worker_request_counts: self.worker_request_counts[normalized_url] = 0 self.worker_failure_counts[normalized_url] = 0 + self.worker_video_support[normalized_url] = None if self.verbose: print(f"[diffusion-router] Added new worker: {normalized_url}") @@ -422,6 +476,7 @@ async def add_worker(self, request: Request): self.register_worker(worker_url) except ValueError as exc: return JSONResponse(status_code=400, content={"error": str(exc)}) + await self.refresh_worker_video_support(worker_url) return { "status": "success", "worker_urls": list(self.worker_request_counts.keys()), diff --git a/tests/unit/test_diffusion_router.py b/tests/unit/test_diffusion_router.py index 223835d..3574fe0 100644 --- a/tests/unit/test_diffusion_router.py +++ b/tests/unit/test_diffusion_router.py @@ -52,7 +52,7 @@ def test_selects_min_load(self, router_factory): router = router_factory( {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} ) - selected = router._use_url() + selected = router._select_worker_by_routing() assert selected == "http://w2:8000" assert router.worker_request_counts["http://w2:8000"] == 3 @@ -61,7 +61,7 @@ def test_excludes_dead_workers(self, router_factory): {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8}, dead={"http://w2:8000"}, ) - selected = router._use_url() + selected = router._select_worker_by_routing() assert selected == "http://w1:8000" assert router.worker_request_counts["http://w1:8000"] == 6 @@ -74,7 +74,7 @@ def test_cycles_workers(self, router_factory): {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, routing_algorithm="round-robin", ) - results = [router._use_url() for _ in range(6)] + results = [router._select_worker_by_routing() for _ in range(6)] workers = list(router.worker_request_counts.keys()) expected = [workers[i % 3] for i in range(6)] assert results == expected @@ -87,7 +87,7 @@ def test_excludes_dead_workers(self, router_factory): dead={"http://w2:8000"}, routing_algorithm="round-robin", ) - results = [router._use_url() for _ in range(4)] + results = [router._select_worker_by_routing() for _ in range(4)] assert "http://w2:8000" not in results assert all(url in ("http://w1:8000", "http://w3:8000") for url in results) @@ -105,7 +105,7 @@ def test_selects_from_valid_workers(self, router_factory): # Reset counts so they do not grow unbounded for url in router.worker_request_counts: router.worker_request_counts[url] = 0 - seen.add(router._use_url()) + seen.add(router._select_worker_by_routing()) assert seen == {"http://w1:8000", "http://w2:8000", "http://w3:8000"} def test_excludes_dead_workers(self, router_factory): @@ -115,7 +115,7 @@ def test_excludes_dead_workers(self, router_factory): routing_algorithm="random", ) for _ in range(20): - url = router._use_url() + url = router._select_worker_by_routing() assert url != "http://w2:8000" router.worker_request_counts[url] -= 1 # reset increment @@ -127,7 +127,7 @@ class TestErrorCases: def test_raises_when_no_workers(self, router_factory, algorithm): router = router_factory({}, routing_algorithm=algorithm) with pytest.raises(RuntimeError, match="No workers registered"): - router._use_url() + router._select_worker_by_routing() @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) def test_raises_when_all_dead(self, router_factory, algorithm): @@ -137,16 +137,16 @@ def test_raises_when_all_dead(self, router_factory, algorithm): routing_algorithm=algorithm, ) with pytest.raises(RuntimeError, match="No healthy workers"): - router._use_url() + router._select_worker_by_routing() class TestCountManagement: - """Test that _use_url / _finish_url correctly track active request counts.""" + """Test that _select_worker_by_routing / _finish_url correctly track active request counts.""" @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) def test_increment_and_finish(self, router_factory, algorithm): router = router_factory({"http://w1:8000": 0}, routing_algorithm=algorithm) - url = router._use_url() + url = router._select_worker_by_routing() assert router.worker_request_counts[url] == 1 router._finish_url(url) assert router.worker_request_counts[url] == 0