diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml deleted file mode 100644 index 48d62d9..0000000 --- a/.github/workflows/pr-test.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: PR Test - -on: - pull_request: - branches: [main] - -jobs: - pr-test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.12" - - - name: Install package and test deps - run: | - python -m pip install --upgrade pip - python -m pip install -e . - python -m pip install pytest - - - name: Run CPU-only tests - run: pytest tests/unit -v diff --git a/development.md b/development.md index 14cada0..bc1ca23 100644 --- a/development.md +++ b/development.md @@ -10,11 +10,7 @@ Run tests: ```bash pip install pytest -# CPU-only tests (unit + fake e2e) pytest tests/unit -v - -# Real E2E tests (GPU required, longer runtime) -pytest tests/e2e/test_e2e_sglang.py -v -s ``` ## Benchmark Scripts diff --git a/pyproject.toml b/pyproject.toml index 581f055..107ac21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,4 +33,4 @@ package-dir = { "" = "src" } where = ["src"] [tool.pytest.ini_options] -testpaths = ["tests/unit", "tests/e2e"] +testpaths = ["tests/unit"] diff --git a/src/sglang_diffusion_routing/cli/main.py b/src/sglang_diffusion_routing/cli/main.py index e260aab..5d1336a 100644 --- a/src/sglang_diffusion_routing/cli/main.py +++ b/src/sglang_diffusion_routing/cli/main.py @@ -4,6 +4,7 @@ from __future__ import annotations import argparse +import asyncio import sys from sglang_diffusion_routing import DiffusionRouter @@ -25,9 +26,18 @@ def _run_router_server( worker_urls if worker_urls is not None else args.worker_urls or [] ) router = DiffusionRouter(args, verbose=args.verbose) + refresh_tasks = [] for url in worker_urls: normalized_url = router.normalize_worker_url(url) router.register_worker(normalized_url) + refresh_tasks.append(router.refresh_worker_video_support(normalized_url)) + + if refresh_tasks: + + async def _refresh_all_worker_video_support() -> None: + await asyncio.gather(*refresh_tasks) + + asyncio.run(_refresh_all_worker_video_support()) print(f"{log_prefix} starting router on {args.host}:{args.port}", flush=True) print( diff --git a/src/sglang_diffusion_routing/router/diffusion_router.py b/src/sglang_diffusion_routing/router/diffusion_router.py index d301f13..bd0307b 100644 --- a/src/sglang_diffusion_routing/router/diffusion_router.py +++ b/src/sglang_diffusion_routing/router/diffusion_router.py @@ -69,16 +69,6 @@ def _setup_routes(self) -> None: ) async def _start_background_health_check(self) -> None: - # Probe video capability for pre-registered workers in the running event loop. - unknown_workers = [ - url for url, support in self.worker_video_support.items() if support is None - ] - if unknown_workers: - await asyncio.gather( - *(self.refresh_worker_video_support(url) for url in unknown_workers), - return_exceptions=True, - ) - if self._health_task is None or self._health_task.done(): self._health_task = asyncio.create_task(self._health_check_loop()) @@ -393,12 +383,6 @@ async def generate(self, request: Request): async def generate_video(self, request: Request): """Route video generation to /v1/videos.""" - if not self.worker_request_counts: - return JSONResponse( - status_code=503, - content={"error": "No workers registered in the pool"}, - ) - candidate_workers = [ worker_url for worker_url, support in self.worker_video_support.items() diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 09f8c8b..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Pytest configuration: force local src import precedence.""" - -from __future__ import annotations - -import sys -from pathlib import Path - -src_str = str(Path(__file__).resolve().parent.parent / "src") -while src_str in sys.path: - sys.path.remove(src_str) -sys.path.insert(0, src_str) diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/e2e/test_e2e_sglang.py b/tests/e2e/test_e2e_sglang.py deleted file mode 100644 index 5cecf1a..0000000 --- a/tests/e2e/test_e2e_sglang.py +++ /dev/null @@ -1,399 +0,0 @@ -""" -End-to-end tests with real sglang diffusion workers. - -Requires: - - sglang installed with diffusion support: pip install "sglang[diffusion]" - - At least 1 GPU available - - Model weights accessible (downloads on first run) - -These tests are SKIPPED automatically when sglang or GPU is not available. -To run explicitly: - - pytest tests/e2e/test_e2e_sglang.py -v -s - -Override model/GPU config via environment variables: - SGLANG_TEST_MODEL Model path (default: Qwen/Qwen-Image) - SGLANG_TEST_NUM_GPUS GPUs per worker (default: 1) - SGLANG_TEST_NUM_WORKERS Number of workers (default: 2) - SGLANG_TEST_TIMEOUT Startup timeout in seconds (default: 600) -""" - -from __future__ import annotations - -import base64 -import os -import shutil -import signal -import socket -import subprocess -import sys -import time -from pathlib import Path - -import httpx -import pytest - -REPO_ROOT = Path(__file__).resolve().parents[2] -PYTHON = sys.executable - -DEFAULT_MODEL = "Qwen/Qwen-Image" -DEFAULT_NUM_GPUS = 1 -DEFAULT_NUM_WORKERS = 2 -DEFAULT_TIMEOUT = 600 # sglang model loading can be slow - - -def _has_sglang() -> bool: - try: - import sglang # noqa: F401 - - return True - except ImportError: - return False - - -def _gpu_count() -> int: - try: - import torch - - return torch.cuda.device_count() - except Exception: - return 0 - - -def _get_env(key: str, default: str) -> str: - return os.environ.get(key, default) - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -def _env() -> dict[str, str]: - env = os.environ.copy() - src = str(REPO_ROOT / "src") - old = env.get("PYTHONPATH", "") - env["PYTHONPATH"] = f"{src}:{old}" if old else src - return env - - -def _wait_healthy( - url: str, timeout: float, label: str = "", proc: subprocess.Popen | None = None -) -> None: - deadline = time.monotonic() + timeout - last_log = 0.0 - while time.monotonic() < deadline: - if proc is not None and proc.poll() is not None: - raise RuntimeError( - f"{label} process exited with code {proc.returncode} during startup" - ) - try: - r = httpx.get(f"{url}/health", timeout=5.0) - if r.status_code == 200: - return - except httpx.HTTPError: - pass - now = time.monotonic() - if now - last_log >= 30: - elapsed = now - (deadline - timeout) - print(f" Waiting for {label}... ({elapsed:.0f}s)", flush=True) - last_log = now - time.sleep(2) - raise TimeoutError(f"{label} at {url} not healthy within {timeout}s") - - -def _kill_proc(proc: subprocess.Popen) -> None: - if proc.poll() is not None: - return - try: - os.killpg(proc.pid, signal.SIGTERM) - except (ProcessLookupError, PermissionError): - pass - try: - proc.wait(timeout=15) - except subprocess.TimeoutExpired: - try: - os.killpg(proc.pid, signal.SIGKILL) - except (ProcessLookupError, PermissionError): - pass - proc.wait(timeout=5) - - -# -- Skip conditions ------------------------------------------------------- - -_skip_no_sglang = pytest.mark.skipif( - not _has_sglang() or shutil.which("sglang") is None, - reason="sglang not installed or 'sglang' CLI not in PATH", -) -_skip_no_gpu = pytest.mark.skipif( - _gpu_count() == 0, - reason="No GPU available", -) - -pytestmark = [_skip_no_sglang, _skip_no_gpu] - - -# -- Fixtures --------------------------------------------------------------- - - -class SglangWorker: - def __init__(self, proc: subprocess.Popen, url: str): - self.proc = proc - self.url = url - - -@pytest.fixture(scope="module") -def sglang_config(): - model = _get_env("SGLANG_TEST_MODEL", DEFAULT_MODEL) - num_gpus = int(_get_env("SGLANG_TEST_NUM_GPUS", str(DEFAULT_NUM_GPUS))) - num_workers = int(_get_env("SGLANG_TEST_NUM_WORKERS", str(DEFAULT_NUM_WORKERS))) - timeout = int(_get_env("SGLANG_TEST_TIMEOUT", str(DEFAULT_TIMEOUT))) - - gpus_available = _gpu_count() - needed = num_workers * num_gpus - if gpus_available < needed: - pytest.skip( - f"Need {needed} GPUs ({num_workers} workers x {num_gpus} GPUs), " - f"only {gpus_available} available" - ) - - return { - "model": model, - "num_gpus": num_gpus, - "num_workers": num_workers, - "timeout": timeout, - } - - -@pytest.fixture(scope="module") -def sglang_workers(sglang_config): - """Launch real sglang diffusion worker processes.""" - workers = [] - procs = [] - env = _env() - gpu_pool = list(range(_gpu_count())) - - for i in range(sglang_config["num_workers"]): - port = _find_free_port() - gpu_start = i * sglang_config["num_gpus"] - gpu_end = gpu_start + sglang_config["num_gpus"] - gpu_ids = ",".join(str(gpu_pool[g]) for g in range(gpu_start, gpu_end)) - - worker_env = dict(env) - worker_env["CUDA_VISIBLE_DEVICES"] = gpu_ids - - cmd = [ - "sglang", - "serve", - "--model-path", - sglang_config["model"], - "--num-gpus", - str(sglang_config["num_gpus"]), - "--host", - "127.0.0.1", - "--port", - str(port), - ] - - print( - f"\n[sglang-test] Starting worker {i} on port {port} (GPU: {gpu_ids})", - flush=True, - ) - proc = subprocess.Popen( - cmd, - env=worker_env, - start_new_session=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - procs.append(proc) - workers.append(SglangWorker(proc, f"http://127.0.0.1:{port}")) - - try: - for w in workers: - _wait_healthy( - w.url, - sglang_config["timeout"], - label=f"sglang worker {w.url}", - proc=w.proc, - ) - except (RuntimeError, TimeoutError) as exc: - # Collect stderr from the first failed worker for diagnostics - stderr_snippet = "" - for p in procs: - if p.poll() is not None and p.stderr: - try: - stderr_snippet = p.stderr.read(2048).decode( - "utf-8", errors="replace" - ) - except Exception: - pass - if stderr_snippet: - break - for p in procs: - _kill_proc(p) - pytest.skip( - f"sglang worker failed to start: {exc}" - + (f"\nstderr: {stderr_snippet[:500]}" if stderr_snippet else "") - ) - - yield workers - - for p in procs: - _kill_proc(p) - - -@pytest.fixture(scope="module") -def router_url(sglang_workers): - """Launch a real router connected to real sglang workers.""" - port = _find_free_port() - worker_urls = [w.url for w in sglang_workers] - cmd = [ - PYTHON, - "-m", - "sglang_diffusion_routing", - "--host", - "127.0.0.1", - "--port", - str(port), - "--worker-urls", - *worker_urls, - "--routing-algorithm", - "least-request", - "--log-level", - "warning", - ] - proc = subprocess.Popen( - cmd, - env=_env(), - start_new_session=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - url = f"http://127.0.0.1:{port}" - try: - _wait_healthy(url, 30, label="router", proc=proc) - except Exception: - _kill_proc(proc) - raise - - yield url - _kill_proc(proc) - - -# -- Tests ------------------------------------------------------------------ - - -class TestSglangHealth: - def test_router_healthy(self, router_url): - r = httpx.get(f"{router_url}/health", timeout=10.0) - assert r.status_code == 200 - assert r.json()["status"] == "healthy" - - def test_workers_listed(self, router_url, sglang_workers): - urls = httpx.get(f"{router_url}/list_workers", timeout=10.0).json()["urls"] - assert len(urls) == len(sglang_workers) - - -class TestSglangImageGeneration: - def test_b64_json_generates_real_image(self, router_url, sglang_config): - """Generate a real image through sglang and verify it's a valid PNG.""" - r = httpx.post( - f"{router_url}/generate", - json={ - "model": sglang_config["model"], - "prompt": "a simple red circle on white background", - "num_images": 1, - "response_format": "b64_json", - }, - timeout=120.0, - ) - assert r.status_code == 200 - body = r.json() - assert "data" in body - assert len(body["data"]) == 1 - - img_bytes = base64.b64decode(body["data"][0]["b64_json"]) - # Verify PNG magic bytes - assert img_bytes[:8] == b"\x89PNG\r\n\x1a\n" - # Real image should be substantially larger than a 1x1 pixel - assert len(img_bytes) > 1000 - - def test_multiple_images(self, router_url, sglang_config): - r = httpx.post( - f"{router_url}/generate", - json={ - "model": sglang_config["model"], - "prompt": "a blue square", - "num_images": 2, - "response_format": "b64_json", - }, - timeout=120.0, - ) - assert r.status_code == 200 - assert len(r.json()["data"]) == 2 - - def test_url_format(self, router_url, sglang_config): - """Generate with response_format=url (requires worker file storage).""" - r = httpx.post( - f"{router_url}/generate", - json={ - "model": sglang_config["model"], - "prompt": "a green triangle", - "num_images": 1, - "response_format": "url", - }, - timeout=120.0, - ) - # url format may require cloud storage config — accept either success or - # a clear error about storage, not a crash - assert r.status_code in (200, 400, 500) - if r.status_code == 200: - assert "url" in r.json()["data"][0] - - -class TestSglangLoadBalancing: - def test_requests_distributed(self, router_url, sglang_workers, sglang_config): - """With multiple workers, requests should be distributed.""" - if len(sglang_workers) < 2: - pytest.skip("Need at least 2 workers for load balancing test") - - for _ in range(4): - r = httpx.post( - f"{router_url}/generate", - json={ - "model": sglang_config["model"], - "prompt": "test", - "num_images": 1, - "response_format": "b64_json", - }, - timeout=120.0, - ) - assert r.status_code == 200 - - # Verify health shows all workers still alive - health = httpx.get(f"{router_url}/health_workers", timeout=10.0).json() - assert all(not w["is_dead"] for w in health["workers"]) - - -class TestSglangProxy: - def test_get_model_info(self, router_url): - """Proxy should forward GET requests to worker.""" - r = httpx.get(f"{router_url}/get_model_info", timeout=30.0) - # sglang workers expose /get_model_info - assert r.status_code in (200, 404) - - -class TestSglangVideoEndpoint: - def test_generate_video_rejects_image_only_workers(self, router_url): - """Image-only workers (e.g. Qwen/Qwen-Image) should return 400 for /generate_video.""" - r = httpx.post( - f"{router_url}/generate_video", - json={"prompt": "a walking cat", "num_frames": 8}, - timeout=10.0, - ) - assert r.status_code == 400 - body = r.json() - assert "error" in body - assert body["error"] # non-empty error message diff --git a/tests/unit/fake_worker.py b/tests/unit/fake_worker.py deleted file mode 100644 index be8f896..0000000 --- a/tests/unit/fake_worker.py +++ /dev/null @@ -1,167 +0,0 @@ -#!/usr/bin/env python3 -""" -Fake sglang diffusion worker for e2e testing. - -Implements the same HTTP API contract as a real sglang diffusion worker, -but returns canned responses without any GPU or model dependencies. - -Usage: - python tests/unit/fake_worker.py --port 19000 - python tests/unit/fake_worker.py --port 19000 --fail-rate 0.5 - python tests/unit/fake_worker.py --port 19000 --latency 0.2 -""" - -from __future__ import annotations - -import argparse -import asyncio -import base64 -import random -import time - -import uvicorn -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse - -# 1x1 red PNG pixel -_TINY_PNG = base64.b64decode( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4" - "2mP8z8BQDwADhQGAWjR9awAAAABJRU5ErkJggg==" -) -_TINY_PNG_B64 = base64.b64encode(_TINY_PNG).decode() - - -def create_app( - fail_rate: float = 0.0, - latency: float = 0.0, - worker_id: str = "fake-worker", - task_type: str = "T2V", -) -> FastAPI: - app = FastAPI() - request_count = {"total": 0, "generate": 0, "video": 0, "weights": 0} - - @app.get("/health") - async def health(): - return {"status": "ok", "worker_id": worker_id} - - @app.get("/v1/models") - async def list_models(): - return { - "object": "list", - "data": [ - { - "id": "fake-model", - "task_type": task_type, - } - ], - } - - @app.post("/v1/images/generations") - async def generate_image(request: Request): - request_count["total"] += 1 - request_count["generate"] += 1 - - if latency > 0: - await asyncio.sleep(latency) - - if fail_rate > 0 and random.random() < fail_rate: - return JSONResponse( - status_code=500, - content={"detail": "Simulated worker failure"}, - ) - - body = await request.json() - response_format = body.get("response_format", "url") - prompt = body.get("prompt", "") - n = body.get("n", body.get("num_images", 1)) - - data = [] - for i in range(n): - if response_format == "b64_json": - data.append( - { - "b64_json": _TINY_PNG_B64, - "revised_prompt": prompt, - "index": i, - } - ) - else: - data.append( - { - "url": f"http://localhost/files/img_{request_count['generate']:04d}_{i}.png", - "revised_prompt": prompt, - "index": i, - } - ) - - return { - "created": int(time.time()), - "data": data, - "model": body.get("model", "fake-model"), - "worker_id": worker_id, - } - - @app.post("/v1/videos") - async def generate_video(request: Request): - request_count["total"] += 1 - request_count["video"] += 1 - - if latency > 0: - await asyncio.sleep(latency) - - body = await request.json() - prompt = body.get("prompt", "") - - return { - "created": int(time.time()), - "data": [ - { - "url": f"http://localhost/files/vid_{request_count['video']:04d}.mp4", - "revised_prompt": prompt, - } - ], - "model": body.get("model", "fake-model"), - "worker_id": worker_id, - } - - @app.post("/update_weights_from_disk") - async def update_weights(request: Request): - request_count["total"] += 1 - request_count["weights"] += 1 - body = await request.json() - return { - "ok": True, - "model_path": body.get("model_path", ""), - "worker_id": worker_id, - } - - @app.get("/stats") - async def stats(): - """Test helper: return request counts.""" - return request_count - - return app - - -def main(): - parser = argparse.ArgumentParser(description="Fake sglang diffusion worker") - parser.add_argument("--host", default="127.0.0.1") - parser.add_argument("--port", type=int, required=True) - parser.add_argument("--fail-rate", type=float, default=0.0) - parser.add_argument("--latency", type=float, default=0.0) - parser.add_argument("--worker-id", type=str, default=None) - parser.add_argument("--task-type", type=str, default="T2V") - args = parser.parse_args() - - worker_id = args.worker_id or f"fake-worker-{args.port}" - app = create_app( - fail_rate=args.fail_rate, - latency=args.latency, - worker_id=worker_id, - task_type=args.task_type, - ) - uvicorn.run(app, host=args.host, port=args.port, log_level="warning") - - -if __name__ == "__main__": - main() diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 1899e15..1ef2a85 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -1,14 +1,8 @@ -"""Unit tests for CLI argument parsing. - -Tests the argparse configuration directly — no mocks, no process spawning. -CLI integration (actually starting the router) is covered in e2e tests. -""" - from __future__ import annotations -import pytest +from unittest import mock -from sglang_diffusion_routing.cli.main import build_parser +from sglang_diffusion_routing.cli.main import build_parser, run_cli class TestCLIParser: @@ -25,7 +19,7 @@ def test_defaults(self): assert args.verbose is False assert args.log_level == "info" - def test_full_args(self): + def test_parses_worker_urls(self): args = build_parser().parse_args( [ "--host", @@ -37,14 +31,6 @@ def test_full_args(self): "http://localhost:10092", "--routing-algorithm", "round-robin", - "--timeout", - "0.5", - "--max-connections", - "500", - "--health-check-interval", - "30", - "--health-check-failure-threshold", - "5", "--verbose", "--log-level", "warning", @@ -54,13 +40,17 @@ def test_full_args(self): assert args.port == 31000 assert args.worker_urls == ["http://localhost:10090", "http://localhost:10092"] assert args.routing_algorithm == "round-robin" - assert args.timeout == 0.5 - - def test_rejects_invalid_routing_algorithm(self): - with pytest.raises(SystemExit): - build_parser().parse_args(["--routing-algorithm", "invalid-algo"]) - - def test_accepts_all_valid_algorithms(self): - for algo in ("least-request", "round-robin", "random"): - args = build_parser().parse_args(["--routing-algorithm", algo]) - assert args.routing_algorithm == algo + assert args.verbose is True + assert args.log_level == "warning" + + +def test_run_cli_calls_router_runner(): + with mock.patch("sglang_diffusion_routing.cli.main._run_router_server") as mock_run: + code = run_cli(["--port", "30123", "--worker-urls", "http://localhost:10090"]) + assert code == 0 + mock_run.assert_called_once() + args = mock_run.call_args.args[0] + assert args.port == 30123 + assert args.worker_urls == ["http://localhost:10090"] + assert mock_run.call_args.kwargs["worker_urls"] == ["http://localhost:10090"] + assert mock_run.call_args.kwargs["log_prefix"] == "[sglang-d-router]" diff --git a/tests/unit/test_diffusion_router.py b/tests/unit/test_diffusion_router.py index 880a2ee..3574fe0 100644 --- a/tests/unit/test_diffusion_router.py +++ b/tests/unit/test_diffusion_router.py @@ -1,20 +1,15 @@ -"""Unit tests for DiffusionRouter core functionality. - -Tests routing algorithms, count management, worker registration, -and response building. All tests call real code directly — no mocks, -no HTTP, no fake interfaces. -""" - import asyncio import json from argparse import Namespace +from types import SimpleNamespace import pytest from sglang_diffusion_routing import DiffusionRouter -def _make_args(**overrides) -> Namespace: +def make_router_args(**overrides) -> Namespace: + """Create a Namespace with default DiffusionRouter args, applying overrides.""" defaults = dict( host="127.0.0.1", port=30080, @@ -28,32 +23,30 @@ def _make_args(**overrides) -> Namespace: @pytest.fixture def router_factory(): - created: list[DiffusionRouter] = [] - - def _create(workers=None, dead=None, **kw): - r = DiffusionRouter(_make_args(**kw)) - if workers is not None: - r.worker_request_counts = dict(workers) - r.worker_failure_counts = {u: 0 for u in workers} + """Factory fixture that creates routers and closes their clients at teardown.""" + created_routers: list[DiffusionRouter] = [] + + def _create( + workers: dict[str, int], + dead: set[str] | None = None, + **arg_overrides, + ) -> DiffusionRouter: + router = DiffusionRouter(make_router_args(**arg_overrides)) + router.worker_request_counts = dict(workers) + router.worker_failure_counts = {url: 0 for url in workers} if dead: - r.dead_workers = set(dead) - created.append(r) - return r + router.dead_workers = set(dead) + created_routers.append(router) + return router yield _create - for r in created: - asyncio.run(r.client.aclose()) - -# ── Routing algorithms ──────────────────────────────────────────────── + for router in created_routers: + asyncio.run(router.client.aclose()) class TestLeastRequest: - def test_picks_min_load(self, router_factory): - r = router_factory( - {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} - ) - assert r._select_worker_by_routing() == "http://w2:8000" + """Test the least-request (default) load-balancing algorithm.""" def test_selects_min_load(self, router_factory): router = router_factory( @@ -74,71 +67,77 @@ def test_excludes_dead_workers(self, router_factory): class TestRoundRobin: - def test_cycles_in_order(self, router_factory): - r = router_factory( + """Test the round-robin load-balancing algorithm.""" + + def test_cycles_workers(self, router_factory): + router = router_factory( {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, routing_algorithm="round-robin", ) - results = [r._select_worker_by_routing() for _ in range(6)] - workers = list(r.worker_request_counts.keys()) + results = [router._select_worker_by_routing() for _ in range(6)] + workers = list(router.worker_request_counts.keys()) expected = [workers[i % 3] for i in range(6)] assert results == expected for url in workers: - assert r.worker_request_counts[url] == 2 + assert router.worker_request_counts[url] == 2 - def test_skips_dead_workers(self, router_factory): - r = router_factory( + def test_excludes_dead_workers(self, router_factory): + router = router_factory( {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, dead={"http://w2:8000"}, routing_algorithm="round-robin", ) - results = [r._select_worker_by_routing() for _ in range(4)] + results = [router._select_worker_by_routing() for _ in range(4)] assert "http://w2:8000" not in results assert all(url in ("http://w1:8000", "http://w3:8000") for url in results) -class TestRandomRouting: - def test_covers_all_workers(self, router_factory): - r = router_factory( +class TestRandom: + """Test the random load-balancing algorithm.""" + + def test_selects_from_valid_workers(self, router_factory): + router = router_factory( {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, routing_algorithm="random", ) seen = set() for _ in range(30): # Reset counts so they do not grow unbounded - for url in r.worker_request_counts: - r.worker_request_counts[url] = 0 - seen.add(r._select_worker_by_routing()) + for url in router.worker_request_counts: + router.worker_request_counts[url] = 0 + seen.add(router._select_worker_by_routing()) assert seen == {"http://w1:8000", "http://w2:8000", "http://w3:8000"} def test_excludes_dead_workers(self, router_factory): - r = router_factory( + router = router_factory( {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, dead={"http://w2:8000"}, routing_algorithm="random", ) for _ in range(20): - url = r._select_worker_by_routing() + url = router._select_worker_by_routing() assert url != "http://w2:8000" - r.worker_request_counts[url] -= 1 # reset increment + router.worker_request_counts[url] -= 1 # reset increment -class TestRoutingEdgeCases: - @pytest.mark.parametrize("algo", ["least-request", "round-robin", "random"]) - def test_no_workers_raises(self, router_factory, algo): - r = router_factory({}, routing_algorithm=algo) +class TestErrorCases: + """Test error handling across all routing algorithms.""" + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_raises_when_no_workers(self, router_factory, algorithm): + router = router_factory({}, routing_algorithm=algorithm) with pytest.raises(RuntimeError, match="No workers registered"): - r._select_worker_by_routing() + router._select_worker_by_routing() - @pytest.mark.parametrize("algo", ["least-request", "round-robin", "random"]) - def test_all_dead_raises(self, router_factory, algo): - r = router_factory( + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_raises_when_all_dead(self, router_factory, algorithm): + router = router_factory( {"http://w1:8000": 0, "http://w2:8000": 0}, dead={"http://w1:8000", "http://w2:8000"}, - routing_algorithm=algo, + routing_algorithm=algorithm, ) with pytest.raises(RuntimeError, match="No healthy workers"): - r._select_worker_by_routing() + router._select_worker_by_routing() class TestCountManagement: @@ -153,145 +152,78 @@ def test_increment_and_finish(self, router_factory, algorithm): assert router.worker_request_counts[url] == 0 -# ── Worker registration ────────────────────────────────────────────── - - -class TestRegisterWorker: - def test_registers_and_deduplicates(self, router_factory): - r = router_factory() - r.register_worker("http://w1:8000") - r.register_worker("http://w1:8000") - r.register_worker("http://w2:9000") - assert len(r.worker_request_counts) == 2 - - def test_registered_worker_is_routable(self, router_factory): - r = router_factory() - r.register_worker("http://w1:8000") - assert r._select_worker_by_routing() == "http://w1:8000" +class TestDefaults: + """Test default routing algorithm when the attribute is absent.""" - def test_rejects_blocked_host(self, router_factory): - r = router_factory() - with pytest.raises(ValueError, match="blocked"): - r.register_worker("http://169.254.169.254:80") - assert len(r.worker_request_counts) == 0 - - def test_dead_then_register_new_restores_routing(self, router_factory): - r = router_factory( - {"http://w1:8000": 0}, - dead={"http://w1:8000"}, + def test_default_algorithm_is_least_request(self): + args = Namespace( + host="127.0.0.1", port=30080, max_connections=100, timeout=120.0 ) - with pytest.raises(RuntimeError): - r._select_worker_by_routing() - r.register_worker("http://w2:9000") - assert r._select_worker_by_routing() == "http://w2:9000" - - -# ── Response building ──────────────────────────────────────────────── - + # args has no routing_algorithm attribute + router = DiffusionRouter(args) + try: + assert router.routing_algorithm == "least-request" + finally: + asyncio.run(router.client.aclose()) -class TestBuildProxyResponse: - def test_small_json(self, router_factory): - r = router_factory() - content = json.dumps({"key": "value"}).encode() - resp = r._build_proxy_response( - content, 200, {"content-type": "application/json"} - ) - assert json.loads(resp.body) == {"key": "value"} - assert resp.status_code == 200 - - def test_large_json_returns_raw(self, router_factory): - r = router_factory() - big = json.dumps({"data": "x" * (300 * 1024)}).encode() - resp = r._build_proxy_response(big, 200, {"content-type": "application/json"}) - assert resp.body == big - - def test_preserves_status_code(self, router_factory): - r = router_factory() - content = json.dumps({"error": "not found"}).encode() - resp = r._build_proxy_response( - content, 404, {"content-type": "application/json"} - ) - assert resp.status_code == 404 +class TestRegressions: + def test_forward_body_error_does_not_leak_request_count(self, router_factory): + router = router_factory({"http://w1:8000": 0}) -# ── Static helpers ──────────────────────────────────────────────────── + class BrokenRequest: + method = "POST" + headers = {"content-type": "application/json"} + url = SimpleNamespace(query="") + async def body(self): + raise RuntimeError("body read failed") -class TestSanitizeResponseHeaders: - def test_removes_hop_by_hop_and_encoding(self): - headers = { - "content-type": "application/json", - "connection": "keep-alive", - "transfer-encoding": "chunked", - "content-length": "1234", - "content-encoding": "gzip", - "x-custom": "value", + response = asyncio.run( + router._forward_to_worker(BrokenRequest(), "v1/images/generations") + ) + assert response.status_code == 502 + assert router.worker_request_counts["http://w1:8000"] == 0 + + def test_register_worker_normalizes_duplicate_urls(self, router_factory): + router = router_factory({}) + router.register_worker("http://LOCALHOST:10090/") + router.register_worker("http://localhost:10090") + assert list(router.worker_request_counts.keys()) == ["http://localhost:10090"] + + def test_register_worker_rejects_metadata_host(self, router_factory): + router = router_factory({}) + with pytest.raises(ValueError, match="host is blocked"): + router.register_worker("http://169.254.169.254:80") + + def test_broadcast_to_workers_collects_per_worker_results(self, router_factory): + router = router_factory({"http://w1:8000": 0, "http://w2:8000": 0}) + + class FakeResponse: + def __init__(self, status_code: int, body: dict): + self.status_code = status_code + self._body = body + + async def aread(self) -> bytes: + return json.dumps(self._body).encode("utf-8") + + responses = { + "http://w1:8000/update_weights_from_disk": FakeResponse(200, {"ok": True}), + "http://w2:8000/update_weights_from_disk": FakeResponse( + 500, {"error": "bad worker"} + ), } - result = DiffusionRouter._sanitize_response_headers(headers) - assert set(result.keys()) == {"content-type", "x-custom"} - - -class TestTryDecodeJson: - def test_valid_json(self): - assert DiffusionRouter._try_decode_json(b'{"a": 1}') == {"a": 1} - def test_invalid_json_returns_raw(self): - result = DiffusionRouter._try_decode_json(b"not json") - assert "raw" in result - - -# ── Constructor ─────────────────────────────────────────────────────── - - -class TestConstructor: - def test_default_algorithm(self): - args = Namespace( - host="127.0.0.1", port=30080, max_connections=100, timeout=120.0 - ) - r = DiffusionRouter(args) - try: - assert r.routing_algorithm == "least-request" - finally: - asyncio.run(r.client.aclose()) + async def fake_post(url, content, headers): + del content, headers + return responses[url] - def test_none_timeout_defaults_to_120(self): - args = Namespace( - host="127.0.0.1", - port=30080, - max_connections=100, - timeout=None, - routing_algorithm="least-request", + router.client.post = fake_post # type: ignore[assignment] + result = asyncio.run( + router._broadcast_to_workers("update_weights_from_disk", b"{}", {}) ) - r = DiffusionRouter(args) - try: - assert r.client.timeout.connect == 120.0 - finally: - asyncio.run(r.client.aclose()) - - def test_initial_state_empty(self): - args = _make_args() - r = DiffusionRouter(args) - try: - assert r.worker_request_counts == {} - assert r.dead_workers == set() - finally: - asyncio.run(r.client.aclose()) - - def test_all_routes_registered(self): - args = _make_args() - r = DiffusionRouter(args) - try: - routes = [route.path for route in r.app.routes] - for path in [ - "/add_worker", - "/list_workers", - "/health", - "/health_workers", - "/generate", - "/generate_video", - "/update_weights_from_disk", - "/{path:path}", - ]: - assert path in routes - finally: - asyncio.run(r.client.aclose()) + assert len(result) == 2 + assert {item["worker_url"] for item in result} == { + "http://w1:8000", + "http://w2:8000", + } diff --git a/tests/unit/test_fake_e2e.py b/tests/unit/test_fake_e2e.py deleted file mode 100644 index 6bca6f6..0000000 --- a/tests/unit/test_fake_e2e.py +++ /dev/null @@ -1,824 +0,0 @@ -""" -CPU-only fake end-to-end tests with real processes. - -Spins up fake worker processes and a real router process, then sends -real HTTP requests through the full stack: - - pytest client -> router process (port) -> fake worker processes (ports) - -No mocks, no monkey-patching. All communication over real TCP sockets. - -Run: - pytest tests/unit/test_fake_e2e.py -v -s -""" - -from __future__ import annotations - -import base64 -import os -import signal -import socket -import subprocess -import sys -import time -from pathlib import Path - -import httpx -import pytest - -REPO_ROOT = Path(__file__).resolve().parents[2] -FAKE_WORKER_SCRIPT = Path(__file__).resolve().parent / "fake_worker.py" -PYTHON = sys.executable - - -# -- Helpers --------------------------------------------------------------- - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -def _wait_healthy(url: str, timeout: float = 10.0) -> None: - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - try: - r = httpx.get(f"{url}/health", timeout=2.0) - if r.status_code == 200: - return - except httpx.HTTPError: - pass - time.sleep(0.2) - raise TimeoutError(f"Service at {url} not healthy within {timeout}s") - - -def _wait_responding(url: str, timeout: float = 10.0) -> None: - """Wait for any HTTP response (even 503).""" - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - try: - httpx.get(f"{url}/health", timeout=2.0) - return - except httpx.HTTPError: - pass - time.sleep(0.2) - raise TimeoutError(f"Service at {url} not responding within {timeout}s") - - -def _kill_proc(proc: subprocess.Popen) -> None: - if proc.poll() is not None: - return - try: - os.killpg(proc.pid, signal.SIGTERM) - except (ProcessLookupError, PermissionError): - pass - try: - proc.wait(timeout=5) - except subprocess.TimeoutExpired: - try: - os.killpg(proc.pid, signal.SIGKILL) - except (ProcessLookupError, PermissionError): - pass - proc.wait(timeout=3) - - -def _env() -> dict[str, str]: - env = os.environ.copy() - src = str(REPO_ROOT / "src") - old = env.get("PYTHONPATH", "") - env["PYTHONPATH"] = f"{src}:{old}" if old else src - return env - - -def _start_worker(worker_id: str, **kw) -> tuple[subprocess.Popen, str]: - port = _find_free_port() - cmd = [ - PYTHON, - str(FAKE_WORKER_SCRIPT), - "--port", - str(port), - "--worker-id", - worker_id, - ] - for k, v in kw.items(): - cmd += [f"--{k.replace('_', '-')}", str(v)] - proc = subprocess.Popen( - cmd, - env=_env(), - start_new_session=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - return proc, f"http://127.0.0.1:{port}" - - -def _start_router(worker_urls: list[str], **kw) -> tuple[subprocess.Popen, str]: - port = _find_free_port() - cmd = [ - PYTHON, - "-m", - "sglang_diffusion_routing", - "--host", - "127.0.0.1", - "--port", - str(port), - "--health-check-interval", - "3600", - "--log-level", - "warning", - ] - if worker_urls: - cmd += ["--worker-urls", *worker_urls] - for k, v in kw.items(): - cmd += [f"--{k.replace('_', '-')}", str(v)] - proc = subprocess.Popen( - cmd, - env=_env(), - start_new_session=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - return proc, f"http://127.0.0.1:{port}" - - -# -- Fixtures -------------------------------------------------------------- - - -class FakeWorker: - def __init__(self, proc, url, worker_id): - self.proc, self.url, self.worker_id = proc, url, worker_id - - def stats(self) -> dict: - return httpx.get(f"{self.url}/stats", timeout=3.0).json() - - -@pytest.fixture(scope="module") -def fake_workers(): - workers, procs = [], [] - for i in range(2): - proc, url = _start_worker(f"worker-{i}") - procs.append(proc) - workers.append(FakeWorker(proc, url, f"worker-{i}")) - for w in workers: - _wait_healthy(w.url) - yield workers - for p in procs: - _kill_proc(p) - - -@pytest.fixture(scope="module") -def router_url(fake_workers): - proc, url = _start_router( - [w.url for w in fake_workers], routing_algorithm="round-robin" - ) - try: - _wait_healthy(url) - except TimeoutError: - _kill_proc(proc) - raise - yield url - _kill_proc(proc) - - -# -- Health & status ------------------------------------------------------- - - -class TestHealth: - def test_healthy_with_workers(self, router_url): - r = httpx.get(f"{router_url}/health", timeout=5.0) - assert r.status_code == 200 - body = r.json() - assert body["status"] == "healthy" - assert body["healthy_workers"] == 2 - assert body["total_workers"] == 2 - - def test_health_workers_detail(self, router_url): - workers = httpx.get(f"{router_url}/health_workers", timeout=5.0).json()[ - "workers" - ] - assert len(workers) == 2 - for w in workers: - assert not w["is_dead"] - assert "url" in w - assert "active_requests" in w - assert w["consecutive_failures"] == 0 - - def test_list_workers(self, router_url): - urls = httpx.get(f"{router_url}/list_workers", timeout=5.0).json()["urls"] - assert len(urls) == 2 - assert all(u.startswith("http://") for u in urls) - - def test_unhealthy_when_no_workers(self): - proc, url = _start_router([]) - try: - _wait_responding(url) - r = httpx.get(f"{url}/health", timeout=5.0) - assert r.status_code == 503 - body = r.json() - assert body["status"] == "unhealthy" - assert body["healthy_workers"] == 0 - assert body["total_workers"] == 0 - finally: - _kill_proc(proc) - - -# -- Worker registration --------------------------------------------------- - - -class TestWorkerRegistration: - def test_add_via_query_param(self, fake_workers): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.post( - f"{rurl}/add_worker", params={"url": fake_workers[0].url}, timeout=5.0 - ) - assert r.status_code == 200 - assert fake_workers[0].url in r.json()["worker_urls"] - finally: - _kill_proc(proc) - - def test_add_via_json_body(self, fake_workers): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.post( - f"{rurl}/add_worker", json={"url": fake_workers[0].url}, timeout=5.0 - ) - assert r.status_code == 200 - assert r.json()["status"] == "success" - finally: - _kill_proc(proc) - - def test_add_deduplicates(self, fake_workers): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - url = fake_workers[0].url - httpx.post(f"{rurl}/add_worker", params={"url": url + "/"}, timeout=5.0) - httpx.post(f"{rurl}/add_worker", params={"url": url}, timeout=5.0) - assert ( - len(httpx.get(f"{rurl}/list_workers", timeout=5.0).json()["urls"]) == 1 - ) - finally: - _kill_proc(proc) - - def test_missing_url_400(self): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.post(f"{rurl}/add_worker", timeout=5.0) - assert r.status_code == 400 - assert "error" in r.json() - finally: - _kill_proc(proc) - - def test_blocked_host_400(self): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.post( - f"{rurl}/add_worker", - params={"url": "http://169.254.169.254:80"}, - timeout=5.0, - ) - assert r.status_code == 400 - assert "blocked" in r.json()["error"] - finally: - _kill_proc(proc) - - def test_invalid_json_400(self): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.post( - f"{rurl}/add_worker", - content=b"bad json", - headers={"content-type": "application/json"}, - timeout=5.0, - ) - assert r.status_code == 400 - finally: - _kill_proc(proc) - - def test_dynamic_worker_receives_traffic(self, fake_workers): - """Add a worker dynamically and verify it actually receives requests.""" - w_proc, w_url = _start_worker("dynamic") - r_proc, rurl = _start_router([]) - try: - _wait_healthy(w_url) - _wait_responding(rurl) - httpx.post(f"{rurl}/add_worker", params={"url": w_url}, timeout=5.0) - assert ( - len(httpx.get(f"{rurl}/list_workers", timeout=5.0).json()["urls"]) == 1 - ) - r = httpx.post( - f"{rurl}/generate", - json={"prompt": "t", "response_format": "b64_json"}, - timeout=10.0, - ) - assert r.status_code == 200 - assert r.json()["worker_id"] == "dynamic" - finally: - _kill_proc(w_proc) - _kill_proc(r_proc) - - -# -- Image generation ------------------------------------------------------ - - -class TestImageGeneration: - def test_b64_json_returns_valid_png(self, router_url): - r = httpx.post( - f"{router_url}/generate", - json={ - "model": "test-model", - "prompt": "cat", - "num_images": 1, - "response_format": "b64_json", - }, - timeout=10.0, - ) - assert r.status_code == 200 - body = r.json() - assert "data" in body - assert "created" in body - img = base64.b64decode(body["data"][0]["b64_json"]) - assert img[:4] == b"\x89PNG" - - def test_url_format(self, router_url): - r = httpx.post( - f"{router_url}/generate", - json={ - "model": "t", - "prompt": "dog", - "num_images": 1, - "response_format": "url", - }, - timeout=10.0, - ) - data = r.json()["data"][0] - assert "url" in data - assert data["url"].startswith("http") - - def test_multiple_images(self, router_url): - r = httpx.post( - f"{router_url}/generate", - json={ - "model": "t", - "prompt": "x", - "num_images": 3, - "response_format": "b64_json", - }, - timeout=10.0, - ) - data = r.json()["data"] - assert len(data) == 3 - # Each image should have an index - indices = [d["index"] for d in data] - assert sorted(indices) == [0, 1, 2] - - def test_prompt_preserved_in_response(self, router_url): - prompt = "a beautiful sunset over the ocean" - r = httpx.post( - f"{router_url}/generate", - json={ - "model": "t", - "prompt": prompt, - "num_images": 1, - "response_format": "b64_json", - }, - timeout=10.0, - ) - assert r.json()["data"][0]["revised_prompt"] == prompt - - def test_model_field_preserved(self, router_url): - r = httpx.post( - f"{router_url}/generate", - json={ - "model": "my-custom-model", - "prompt": "x", - "num_images": 1, - "response_format": "b64_json", - }, - timeout=10.0, - ) - assert r.json()["model"] == "my-custom-model" - - def test_no_workers_503(self): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.post(f"{rurl}/generate", json={"prompt": "t"}, timeout=5.0) - assert r.status_code == 503 - assert "error" in r.json() - finally: - _kill_proc(proc) - - -# -- Video generation ------------------------------------------------------ - - -class TestVideoGeneration: - def test_generate_video(self, router_url): - r = httpx.post( - f"{router_url}/generate_video", - json={"model": "vid-model", "prompt": "river"}, - timeout=10.0, - ) - assert r.status_code == 200 - body = r.json() - assert "url" in body["data"][0] - assert "created" in body - - def test_video_prompt_preserved(self, router_url): - prompt = "a flowing river in autumn" - r = httpx.post( - f"{router_url}/generate_video", - json={"model": "vid", "prompt": prompt}, - timeout=10.0, - ) - assert r.json()["data"][0]["revised_prompt"] == prompt - - def test_no_workers_503(self): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.post(f"{rurl}/generate_video", json={"prompt": "t"}, timeout=5.0) - assert r.status_code == 503 - finally: - _kill_proc(proc) - - -# -- Weight update broadcast ----------------------------------------------- - - -class TestUpdateWeights: - def test_broadcasts_to_all(self, router_url): - r = httpx.post( - f"{router_url}/update_weights_from_disk", - json={"model_path": "/weights/v2"}, - timeout=10.0, - ) - results = r.json()["results"] - assert len(results) == 2 - for res in results: - assert res["status_code"] == 200 - assert res["body"]["ok"] is True - assert res["body"]["model_path"] == "/weights/v2" - assert "worker_url" in res - - def test_empty_pool(self): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.post( - f"{rurl}/update_weights_from_disk", - json={"model_path": "x"}, - timeout=5.0, - ) - assert r.json()["results"] == [] - finally: - _kill_proc(proc) - - -# -- Load balancing -------------------------------------------------------- - - -class TestRoundRobinBalancing: - """Tests using the module-scoped round-robin router.""" - - def test_distributes_evenly(self, router_url, fake_workers): - initial = [w.stats()["generate"] for w in fake_workers] - for _ in range(10): - assert ( - httpx.post( - f"{router_url}/generate", - json={"prompt": "t", "response_format": "b64_json"}, - timeout=10.0, - ).status_code - == 200 - ) - final = [w.stats()["generate"] for w in fake_workers] - deltas = [final[i] - initial[i] for i in range(2)] - assert all(d > 0 for d in deltas) - assert sum(deltas) == 10 - # Round-robin should be perfectly even with 2 workers - assert deltas[0] == deltas[1] == 5 - - def test_worker_id_proves_real_routing(self, router_url): - """Consecutive requests should alternate between workers.""" - ids = [] - for _ in range(4): - r = httpx.post( - f"{router_url}/generate", - json={"prompt": "t", "response_format": "b64_json"}, - timeout=10.0, - ) - ids.append(r.json()["worker_id"]) - # Round-robin: should alternate - assert ids[0] != ids[1] - assert ids[0] == ids[2] - assert ids[1] == ids[3] - - -class TestLeastRequestBalancing: - """Tests with a dedicated least-request router.""" - - def test_prefers_less_loaded_worker(self, fake_workers): - """With one slow worker and one fast worker, least-request should - send more traffic to the fast one.""" - slow_proc, slow_url = _start_worker("slow", latency=0.3) - fast_proc, fast_url = _start_worker("fast", latency=0.0) - r_proc, rurl = _start_router( - [slow_url, fast_url], - routing_algorithm="least-request", - ) - try: - _wait_healthy(slow_url) - _wait_healthy(fast_url) - _wait_healthy(rurl) - - # Send requests concurrently — fast worker should get more - import concurrent.futures - - def send_one(): - return httpx.post( - f"{rurl}/generate", - json={"prompt": "t", "response_format": "b64_json"}, - timeout=15.0, - ).json()["worker_id"] - - with concurrent.futures.ThreadPoolExecutor(max_workers=6) as pool: - futures = [pool.submit(send_one) for _ in range(12)] - results = [f.result() for f in futures] - - fast_count = results.count("fast") - slow_count = results.count("slow") - # Fast worker should handle more requests than slow - assert fast_count > slow_count - assert fast_count + slow_count == 12 - finally: - _kill_proc(slow_proc) - _kill_proc(fast_proc) - _kill_proc(r_proc) - - -# -- Proxy catch-all ------------------------------------------------------- - - -class TestProxy: - def test_stats_proxied(self, router_url): - r = httpx.get(f"{router_url}/stats", timeout=5.0) - assert r.status_code == 200 - body = r.json() - assert "total" in body - assert "generate" in body - assert "video" in body - - def test_no_workers_503(self): - proc, rurl = _start_router([]) - try: - _wait_responding(rurl) - r = httpx.get(f"{rurl}/any/path", timeout=5.0) - assert r.status_code == 503 - finally: - _kill_proc(proc) - - def test_worker_health_proxied(self, router_url): - """The catch-all proxy should forward /worker_health to a worker.""" - # /stats is handled by fake_worker, verifying proxy works for arbitrary paths - r = httpx.get(f"{router_url}/stats", timeout=5.0) - assert r.status_code == 200 - - -# -- Worker failure -------------------------------------------------------- - - -class TestWorkerFailure: - def test_unreachable_worker_502(self): - proc, rurl = _start_router(["http://127.0.0.1:1"]) - try: - _wait_responding(rurl) - r = httpx.post(f"{rurl}/generate", json={"prompt": "t"}, timeout=10.0) - assert r.status_code == 502 - assert "error" in r.json() - finally: - _kill_proc(proc) - - def test_worker_500_forwarded(self): - w_proc, w_url = _start_worker("failing", fail_rate=1.0) - r_proc, rurl = _start_router([w_url]) - try: - _wait_healthy(w_url) - _wait_healthy(rurl) - r = httpx.post( - f"{rurl}/generate", - json={"prompt": "t", "response_format": "b64_json"}, - timeout=10.0, - ) - assert r.status_code == 500 - assert "detail" in r.json() - finally: - _kill_proc(w_proc) - _kill_proc(r_proc) - - def test_worker_killed_returns_502(self, fake_workers): - """Kill a worker process after router starts, verify 502.""" - w_proc, w_url = _start_worker("ephemeral") - r_proc, rurl = _start_router([w_url]) - try: - _wait_healthy(w_url) - _wait_healthy(rurl) - # Verify it works first - assert ( - httpx.post( - f"{rurl}/generate", - json={"prompt": "t", "response_format": "b64_json"}, - timeout=10.0, - ).status_code - == 200 - ) - # Kill the worker - _kill_proc(w_proc) - time.sleep(0.5) - # Now requests should fail - r = httpx.post( - f"{rurl}/generate", - json={"prompt": "t", "response_format": "b64_json"}, - timeout=10.0, - ) - assert r.status_code == 502 - finally: - _kill_proc(w_proc) - _kill_proc(r_proc) - - -# -- Concurrent requests --------------------------------------------------- - - -class TestConcurrency: - def test_concurrent_requests_all_succeed(self, router_url): - """Fire multiple requests concurrently, all should succeed.""" - import concurrent.futures - - def send_one(i): - return httpx.post( - f"{router_url}/generate", - json={"prompt": f"concurrent-{i}", "response_format": "b64_json"}, - timeout=15.0, - ) - - with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool: - futures = [pool.submit(send_one, i) for i in range(16)] - responses = [f.result() for f in futures] - - assert all(r.status_code == 200 for r in responses) - # Verify each response has valid data - for r in responses: - body = r.json() - assert "data" in body - assert len(body["data"]) == 1 - - def test_concurrent_mixed_endpoints(self, router_url): - """Concurrent requests to different endpoints should all work.""" - import concurrent.futures - - def gen_image(): - return httpx.post( - f"{router_url}/generate", - json={"prompt": "img", "response_format": "b64_json"}, - timeout=15.0, - ) - - def gen_video(): - return httpx.post( - f"{router_url}/generate_video", json={"prompt": "vid"}, timeout=15.0 - ) - - def check_health(): - return httpx.get(f"{router_url}/health", timeout=5.0) - - with concurrent.futures.ThreadPoolExecutor(max_workers=6) as pool: - futs = ( - [pool.submit(gen_image) for _ in range(4)] - + [pool.submit(gen_video) for _ in range(4)] - + [pool.submit(check_health) for _ in range(2)] - ) - results = [f.result() for f in futs] - - assert all(r.status_code == 200 for r in results) - - -# -- CLI integration ------------------------------------------------------- - - -class TestCLI: - def test_cli_starts_working_router(self, fake_workers): - port = _find_free_port() - proc = subprocess.Popen( - [ - PYTHON, - "-m", - "sglang_diffusion_routing", - "--host", - "127.0.0.1", - "--port", - str(port), - "--worker-urls", - *[w.url for w in fake_workers], - "--routing-algorithm", - "least-request", - "--log-level", - "warning", - ], - env=_env(), - start_new_session=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - url = f"http://127.0.0.1:{port}" - try: - _wait_healthy(url) - assert httpx.get(f"{url}/health", timeout=5.0).status_code == 200 - r = httpx.post( - f"{url}/generate", - json={"prompt": "cli", "response_format": "b64_json"}, - timeout=10.0, - ) - assert r.status_code == 200 - assert "data" in r.json() - finally: - _kill_proc(proc) - - def test_script_entry_point(self, fake_workers): - """Test the sglang-d-router script entry point.""" - # Find the script next to the Python interpreter in the venv - script = Path(PYTHON).parent / "sglang-d-router" - if not script.exists(): - pytest.skip("sglang-d-router not installed in PATH") - port = _find_free_port() - proc = subprocess.Popen( - [ - str(script), - "--host", - "127.0.0.1", - "--port", - str(port), - "--worker-urls", - *[w.url for w in fake_workers], - "--log-level", - "warning", - ], - env=_env(), - start_new_session=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - url = f"http://127.0.0.1:{port}" - try: - _wait_healthy(url) - assert httpx.get(f"{url}/health", timeout=5.0).status_code == 200 - finally: - _kill_proc(proc) - - def test_help_flag(self): - r = subprocess.run( - [PYTHON, "-m", "sglang_diffusion_routing", "--help"], - capture_output=True, - text=True, - timeout=10, - ) - assert r.returncode == 0 - assert "sglang-d-router" in r.stdout - - def test_verbose_flag(self, fake_workers): - """Router with --verbose should start and work normally.""" - port = _find_free_port() - proc = subprocess.Popen( - [ - PYTHON, - "-m", - "sglang_diffusion_routing", - "--host", - "127.0.0.1", - "--port", - str(port), - "--worker-urls", - *[w.url for w in fake_workers], - "--verbose", - "--log-level", - "warning", - ], - env=_env(), - start_new_session=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - url = f"http://127.0.0.1:{port}" - try: - _wait_healthy(url) - assert httpx.get(f"{url}/health", timeout=5.0).json()["status"] == "healthy" - finally: - _kill_proc(proc) diff --git a/tests/unit/test_router_endpoints.py b/tests/unit/test_router_endpoints.py new file mode 100644 index 0000000..2983b8d --- /dev/null +++ b/tests/unit/test_router_endpoints.py @@ -0,0 +1,68 @@ +from argparse import Namespace + +from fastapi.testclient import TestClient + +from sglang_diffusion_routing import DiffusionRouter + + +def make_router_args(**overrides) -> Namespace: + defaults = dict( + host="127.0.0.1", + port=30080, + max_connections=100, + timeout=120.0, + routing_algorithm="least-request", + health_check_interval=3600, + health_check_failure_threshold=3, + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def test_add_worker_normalizes_and_deduplicates(): + router = DiffusionRouter(make_router_args()) + with TestClient(router.app) as client: + first = client.post("/add_worker", params={"url": "http://LOCALHOST:10090/"}) + assert first.status_code == 200 + + second = client.post("/add_worker", params={"url": "http://localhost:10090"}) + assert second.status_code == 200 + payload = second.json() + assert payload["worker_urls"] == ["http://localhost:10090"] + + listed = client.get("/list_workers") + assert listed.status_code == 200 + assert listed.json()["urls"] == ["http://localhost:10090"] + + +def test_add_worker_rejects_blocked_metadata_host(): + router = DiffusionRouter(make_router_args()) + with TestClient(router.app) as client: + response = client.post( + "/add_worker", params={"url": "http://169.254.169.254:80"} + ) + assert response.status_code == 400 + assert "blocked" in response.json()["error"] + + +def test_update_weights_from_disk_returns_broadcast_results(): + router = DiffusionRouter(make_router_args()) + router.register_worker("http://localhost:10090") + + async def fake_broadcast(path: str, body: bytes, headers: dict): + assert path == "update_weights_from_disk" + assert body == b'{"model_path":"abc"}' + assert headers.get("content-type", "").startswith("application/json") + return [ + { + "worker_url": "http://localhost:10090", + "status_code": 200, + "body": {"ok": True}, + } + ] + + router._broadcast_to_workers = fake_broadcast # type: ignore[assignment] + with TestClient(router.app) as client: + response = client.post("/update_weights_from_disk", json={"model_path": "abc"}) + assert response.status_code == 200 + assert response.json()["results"][0]["status_code"] == 200