[WIP] ci: add more test cases (#25)
Conversation
Summary of ChangesHello @alphabetc1, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly strengthens the testing framework for the sglang-diffusion-routing project. It introduces robust end-to-end tests that validate the router's behavior with both real and simulated SGLang workers, ensuring reliable operation of load balancing, request routing, and worker management. The changes also include minor refactorings to improve the router's startup logic and error handling, making the system more resilient and easier to maintain. Highlights
Changelog
Ignored Files
Activity
Using Gemini Code Assist — The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips. Invoking Gemini: You can request assistance from Gemini at any point by creating a comment using either
Customization — To customize Gemini Code Assist for the GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to the repository's configuration folder. Limitations & Feedback — Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
The pull request introduces new test cases for both unit and end-to-end scenarios, significantly improving test coverage. It also refactors the worker video support probing logic to occur within the background health check, rather than during router startup, which is a good improvement for robustness. Additionally, a check for no registered workers has been added to the generate_video endpoint. The changes are well-structured and enhance the overall stability and testability of the application.
| if not self.worker_request_counts: | ||
| return JSONResponse( | ||
| status_code=503, | ||
| content={"error": "No workers registered in the pool"}, | ||
| ) |
| for url in worker_urls: | ||
| normalized_url = router.normalize_worker_url(url) | ||
| router.register_worker(normalized_url) |
| # Probe video capability for pre-registered workers in the running event loop. | ||
| unknown_workers = [ | ||
| url for url, support in self.worker_video_support.items() if support is None | ||
| ] | ||
| if unknown_workers: | ||
| await asyncio.gather( | ||
| *(self.refresh_worker_video_support(url) for url in unknown_workers), | ||
| return_exceptions=True, | ||
| ) |
There was a problem hiding this comment.
| """ | ||
| End-to-end tests with real sglang diffusion workers. | ||
|
|
||
| Requires: | ||
| - sglang installed with diffusion support: pip install "sglang[diffusion]" | ||
| - At least 1 GPU available | ||
| - Model weights accessible (downloads on first run) | ||
|
|
||
| These tests are SKIPPED automatically when sglang or GPU is not available. | ||
| To run explicitly: | ||
|
|
||
| pytest tests/e2e/test_e2e_sglang.py -v -s | ||
|
|
||
| Override model/GPU config via environment variables: | ||
| SGLANG_TEST_MODEL Model path (default: Qwen/Qwen-Image) | ||
| SGLANG_TEST_NUM_GPUS GPUs per worker (default: 1) | ||
| SGLANG_TEST_NUM_WORKERS Number of workers (default: 2) | ||
| SGLANG_TEST_TIMEOUT Startup timeout in seconds (default: 600) | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import base64 | ||
| import os | ||
| import shutil | ||
| import signal | ||
| import socket | ||
| import subprocess | ||
| import sys | ||
| import time | ||
| from pathlib import Path | ||
|
|
||
| import httpx | ||
| import pytest | ||
|
|
||
# tests/e2e/<this file> -> parents[2] is the repository root.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Run subprocesses with the same interpreter that runs pytest.
PYTHON = sys.executable

# Defaults; each can be overridden via the SGLANG_TEST_* environment
# variables documented in the module docstring.
DEFAULT_MODEL = "Qwen/Qwen-Image"
DEFAULT_NUM_GPUS = 1
DEFAULT_NUM_WORKERS = 2
DEFAULT_TIMEOUT = 600  # sglang model loading can be slow
|
|
||
|
|
||
| def _has_sglang() -> bool: | ||
| try: | ||
| import sglang # noqa: F401 | ||
|
|
||
| return True | ||
| except ImportError: | ||
| return False | ||
|
|
||
|
|
||
| def _gpu_count() -> int: | ||
| try: | ||
| import torch | ||
|
|
||
| return torch.cuda.device_count() | ||
| except Exception: | ||
| return 0 | ||
|
|
||
|
|
||
| def _get_env(key: str, default: str) -> str: | ||
| return os.environ.get(key, default) | ||
|
|
||
|
|
||
| def _find_free_port() -> int: | ||
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | ||
| s.bind(("127.0.0.1", 0)) | ||
| return s.getsockname()[1] | ||
|
|
||
|
|
||
def _env() -> dict[str, str]:
    """Return a copy of the environment with the repo's ``src`` directory
    prepended to PYTHONPATH so subprocesses can import the package under test.
    """
    env = os.environ.copy()
    src = str(REPO_ROOT / "src")
    existing = env.get("PYTHONPATH", "")
    # Use os.pathsep (":" on POSIX, ";" on Windows) instead of a hard-coded
    # ":" so the path list stays valid on every platform.
    env["PYTHONPATH"] = f"{src}{os.pathsep}{existing}" if existing else src
    return env
|
|
||
|
|
||
| def _wait_healthy( | ||
| url: str, timeout: float, label: str = "", proc: subprocess.Popen | None = None | ||
| ) -> None: | ||
| deadline = time.monotonic() + timeout | ||
| last_log = 0.0 | ||
| while time.monotonic() < deadline: | ||
| if proc is not None and proc.poll() is not None: | ||
| raise RuntimeError( | ||
| f"{label} process exited with code {proc.returncode} during startup" | ||
| ) | ||
| try: | ||
| r = httpx.get(f"{url}/health", timeout=5.0) | ||
| if r.status_code == 200: | ||
| return | ||
| except httpx.HTTPError: | ||
| pass | ||
| now = time.monotonic() | ||
| if now - last_log >= 30: | ||
| elapsed = now - (deadline - timeout) | ||
| print(f" Waiting for {label}... ({elapsed:.0f}s)", flush=True) | ||
| last_log = now | ||
| time.sleep(2) | ||
| raise TimeoutError(f"{label} at {url} not healthy within {timeout}s") | ||
|
|
||
|
|
||
| def _kill_proc(proc: subprocess.Popen) -> None: | ||
| if proc.poll() is not None: | ||
| return | ||
| try: | ||
| os.killpg(proc.pid, signal.SIGTERM) | ||
| except (ProcessLookupError, PermissionError): | ||
| pass | ||
| try: | ||
| proc.wait(timeout=15) | ||
| except subprocess.TimeoutExpired: | ||
| try: | ||
| os.killpg(proc.pid, signal.SIGKILL) | ||
| except (ProcessLookupError, PermissionError): | ||
| pass | ||
| proc.wait(timeout=5) | ||
|
|
||
|
|
||
| # -- Skip conditions ------------------------------------------------------- | ||
|
|
||
# Module-wide skip guards: these e2e tests need both the sglang package
# (with its CLI on PATH) and at least one CUDA-visible GPU.
_skip_no_sglang = pytest.mark.skipif(
    not _has_sglang() or shutil.which("sglang") is None,
    reason="sglang not installed or 'sglang' CLI not in PATH",
)
_skip_no_gpu = pytest.mark.skipif(
    _gpu_count() == 0,
    reason="No GPU available",
)

# Applying the marks via pytestmark skips every test in this module when
# either requirement is unmet.
pytestmark = [_skip_no_sglang, _skip_no_gpu]
|
|
||
|
|
||
| # -- Fixtures --------------------------------------------------------------- | ||
|
|
||
|
|
||
class SglangWorker:
    """Pairs a launched worker subprocess with its base URL."""

    def __init__(self, proc: subprocess.Popen, url: str):
        self.url = url
        self.proc = proc
|
|
||
|
|
||
@pytest.fixture(scope="module")
def sglang_config():
    """Resolve test configuration from SGLANG_TEST_* env vars with defaults.

    Skips the whole module when fewer GPUs are visible than
    num_workers * num_gpus requires.
    """
    model = _get_env("SGLANG_TEST_MODEL", DEFAULT_MODEL)
    num_gpus = int(_get_env("SGLANG_TEST_NUM_GPUS", str(DEFAULT_NUM_GPUS)))
    num_workers = int(_get_env("SGLANG_TEST_NUM_WORKERS", str(DEFAULT_NUM_WORKERS)))
    timeout = int(_get_env("SGLANG_TEST_TIMEOUT", str(DEFAULT_TIMEOUT)))

    # Each worker claims a disjoint set of GPUs, so the total must fit.
    gpus_available = _gpu_count()
    needed = num_workers * num_gpus
    if gpus_available < needed:
        pytest.skip(
            f"Need {needed} GPUs ({num_workers} workers x {num_gpus} GPUs), "
            f"only {gpus_available} available"
        )

    return {
        "model": model,
        "num_gpus": num_gpus,
        "num_workers": num_workers,
        "timeout": timeout,
    }
|
|
||
|
|
||
@pytest.fixture(scope="module")
def sglang_workers(sglang_config):
    """Launch real sglang diffusion worker processes.

    Worker output is sent to temp log files instead of subprocess.PIPE:
    the original PIPEs were never drained, so a chatty worker could fill
    the OS pipe buffer and block before ever becoming healthy. On startup
    failure the log file of the first dead worker is read for diagnostics.
    """
    import tempfile

    workers = []
    procs = []
    logs = []
    env = _env()
    gpu_pool = list(range(_gpu_count()))

    for i in range(sglang_config["num_workers"]):
        port = _find_free_port()
        # Assign each worker a disjoint, contiguous slice of the GPU pool.
        gpu_start = i * sglang_config["num_gpus"]
        gpu_end = gpu_start + sglang_config["num_gpus"]
        gpu_ids = ",".join(str(gpu_pool[g]) for g in range(gpu_start, gpu_end))

        worker_env = dict(env)
        worker_env["CUDA_VISIBLE_DEVICES"] = gpu_ids

        cmd = [
            "sglang",
            "serve",
            "--model-path",
            sglang_config["model"],
            "--num-gpus",
            str(sglang_config["num_gpus"]),
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
        ]

        print(
            f"\n[sglang-test] Starting worker {i} on port {port} (GPU: {gpu_ids})",
            flush=True,
        )
        log = tempfile.TemporaryFile()
        logs.append(log)
        proc = subprocess.Popen(
            cmd,
            env=worker_env,
            start_new_session=True,
            stdout=log,
            stderr=subprocess.STDOUT,
        )
        procs.append(proc)
        workers.append(SglangWorker(proc, f"http://127.0.0.1:{port}"))

    try:
        for w in workers:
            _wait_healthy(
                w.url,
                sglang_config["timeout"],
                label=f"sglang worker {w.url}",
                proc=w.proc,
            )
    except (RuntimeError, TimeoutError) as exc:
        # Collect output from the first failed worker for diagnostics.
        stderr_snippet = ""
        for p, log in zip(procs, logs):
            if p.poll() is not None:
                try:
                    log.seek(0)
                    stderr_snippet = log.read(2048).decode(
                        "utf-8", errors="replace"
                    )
                except Exception:
                    pass
                if stderr_snippet:
                    break
        for p in procs:
            _kill_proc(p)
        pytest.skip(
            f"sglang worker failed to start: {exc}"
            + (f"\nstderr: {stderr_snippet[:500]}" if stderr_snippet else "")
        )

    yield workers

    for p in procs:
        _kill_proc(p)
    for log in logs:
        log.close()
|
|
||
|
|
||
@pytest.fixture(scope="module")
def router_url(sglang_workers):
    """Launch a real router connected to real sglang workers.

    Yields the router's base URL; the process is torn down afterwards.
    """
    port = _find_free_port()
    worker_urls = [w.url for w in sglang_workers]
    cmd = [
        PYTHON,
        "-m",
        "sglang_diffusion_routing",
        "--host",
        "127.0.0.1",
        "--port",
        str(port),
        "--worker-urls",
        *worker_urls,
        "--routing-algorithm",
        "least-request",
        "--log-level",
        "warning",
    ]
    # NOTE(review): stdout/stderr PIPEs are never drained; with log-level
    # "warning" the output should stay small, but a chatty router could block
    # on a full pipe buffer — consider log files if startup ever hangs.
    proc = subprocess.Popen(
        cmd,
        env=_env(),
        start_new_session=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    url = f"http://127.0.0.1:{port}"
    try:
        _wait_healthy(url, 30, label="router", proc=proc)
    except Exception:
        # Don't leak the process when the router never comes up.
        _kill_proc(proc)
        raise

    yield url
    _kill_proc(proc)
|
|
||
|
|
||
| # -- Tests ------------------------------------------------------------------ | ||
|
|
||
|
|
||
class TestSglangHealth:
    """Basic liveness checks against the running router."""

    def test_router_healthy(self, router_url):
        resp = httpx.get(f"{router_url}/health", timeout=10.0)
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["status"] == "healthy"

    def test_workers_listed(self, router_url, sglang_workers):
        resp = httpx.get(f"{router_url}/list_workers", timeout=10.0)
        listed = resp.json()["urls"]
        assert len(listed) == len(sglang_workers)
|
|
||
|
|
||
class TestSglangImageGeneration:
    """Image generation through the router against real sglang workers."""

    def test_b64_json_generates_real_image(self, router_url, sglang_config):
        """Generate a real image through sglang and verify it's a valid PNG."""
        r = httpx.post(
            f"{router_url}/generate",
            json={
                "model": sglang_config["model"],
                "prompt": "a simple red circle on white background",
                "num_images": 1,
                "response_format": "b64_json",
            },
            timeout=120.0,
        )
        assert r.status_code == 200
        body = r.json()
        assert "data" in body
        assert len(body["data"]) == 1

        img_bytes = base64.b64decode(body["data"][0]["b64_json"])
        # Verify PNG magic bytes
        assert img_bytes[:8] == b"\x89PNG\r\n\x1a\n"
        # Real image should be substantially larger than a 1x1 pixel
        assert len(img_bytes) > 1000

    def test_multiple_images(self, router_url, sglang_config):
        """num_images=2 must yield exactly two entries in 'data'."""
        r = httpx.post(
            f"{router_url}/generate",
            json={
                "model": sglang_config["model"],
                "prompt": "a blue square",
                "num_images": 2,
                "response_format": "b64_json",
            },
            timeout=120.0,
        )
        assert r.status_code == 200
        assert len(r.json()["data"]) == 2

    def test_url_format(self, router_url, sglang_config):
        """Generate with response_format=url (requires worker file storage)."""
        r = httpx.post(
            f"{router_url}/generate",
            json={
                "model": sglang_config["model"],
                "prompt": "a green triangle",
                "num_images": 1,
                "response_format": "url",
            },
            timeout=120.0,
        )
        # url format may require cloud storage config — accept either success or
        # a clear error about storage, not a crash
        assert r.status_code in (200, 400, 500)
        if r.status_code == 200:
            assert "url" in r.json()["data"][0]
|
|
||
|
|
||
class TestSglangLoadBalancing:
    def test_requests_distributed(self, router_url, sglang_workers, sglang_config):
        """With multiple workers, requests should be distributed."""
        if len(sglang_workers) < 2:
            pytest.skip("Need at least 2 workers for load balancing test")

        # Issue several sequential requests; least-request routing should
        # spread them across the workers.
        for _ in range(4):
            r = httpx.post(
                f"{router_url}/generate",
                json={
                    "model": sglang_config["model"],
                    "prompt": "test",
                    "num_images": 1,
                    "response_format": "b64_json",
                },
                timeout=120.0,
            )
            assert r.status_code == 200

        # Verify health shows all workers still alive
        # NOTE(review): actual per-worker distribution is not asserted here —
        # only that every request succeeded and no worker was marked dead.
        health = httpx.get(f"{router_url}/health_workers", timeout=10.0).json()
        assert all(not w["is_dead"] for w in health["workers"])
|
|
||
|
|
||
class TestSglangProxy:
    """Catch-all proxy behavior for unmatched routes."""

    def test_get_model_info(self, router_url):
        """Proxy should forward GET requests to worker."""
        resp = httpx.get(f"{router_url}/get_model_info", timeout=30.0)
        # sglang workers expose /get_model_info; a 404 is tolerated when the
        # worker build does not provide the route.
        assert resp.status_code in (200, 404)
|
|
||
|
|
||
class TestSglangVideoEndpoint:
    """Video requests against image-only workers must be cleanly rejected."""

    def test_generate_video_rejects_image_only_workers(self, router_url):
        """Image-only workers (e.g. Qwen/Qwen-Image) should return 400 for /generate_video."""
        resp = httpx.post(
            f"{router_url}/generate_video",
            json={"prompt": "a walking cat", "num_frames": 8},
            timeout=10.0,
        )
        assert resp.status_code == 400
        payload = resp.json()
        assert "error" in payload
        assert payload["error"]  # non-empty error message
There was a problem hiding this comment.
The addition of test_e2e_sglang.py is a significant improvement to the test suite. These end-to-end tests with real SGLang diffusion workers provide high confidence in the system's integration and functionality, especially given the GPU requirements and model loading complexities. The use of fixtures for worker and router setup is well-implemented.
| #!/usr/bin/env python3 | ||
| """ | ||
| Fake sglang diffusion worker for e2e testing. | ||
|
|
||
| Implements the same HTTP API contract as a real sglang diffusion worker, | ||
| but returns canned responses without any GPU or model dependencies. | ||
|
|
||
| Usage: | ||
| python tests/unit/fake_worker.py --port 19000 | ||
| python tests/unit/fake_worker.py --port 19000 --fail-rate 0.5 | ||
| python tests/unit/fake_worker.py --port 19000 --latency 0.2 | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import asyncio | ||
| import base64 | ||
| import random | ||
| import time | ||
|
|
||
| import uvicorn | ||
| from fastapi import FastAPI, Request | ||
| from fastapi.responses import JSONResponse | ||
|
|
||
| # 1x1 red PNG pixel | ||
| _TINY_PNG = base64.b64decode( | ||
| "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4" | ||
| "2mP8z8BQDwADhQGAWjR9awAAAABJRU5ErkJggg==" | ||
| ) | ||
| _TINY_PNG_B64 = base64.b64encode(_TINY_PNG).decode() | ||
|
|
||
|
|
||
def create_app(
    fail_rate: float = 0.0,
    latency: float = 0.0,
    worker_id: str = "fake-worker",
    task_type: str = "T2V",
) -> FastAPI:
    """Build a FastAPI app that mimics a sglang diffusion worker.

    Args:
        fail_rate: probability (0..1) that an image request returns HTTP 500.
        latency: artificial per-request delay in seconds.
        worker_id: identifier echoed in every response for test assertions.
        task_type: advertised model task type (e.g. "T2V"); the router
            presumably uses this to decide video capability — confirm.
    """
    app = FastAPI()
    # Per-app request counters, exposed verbatim via GET /stats so tests can
    # assert how many requests each fake worker actually received.
    request_count = {"total": 0, "generate": 0, "video": 0, "weights": 0}

    @app.get("/health")
    async def health():
        # Always healthy; failures are only simulated on generation routes.
        return {"status": "ok", "worker_id": worker_id}

    @app.get("/v1/models")
    async def list_models():
        # Single fake model advertising the configured task_type.
        return {
            "object": "list",
            "data": [
                {
                    "id": "fake-model",
                    "task_type": task_type,
                }
            ],
        }

    @app.post("/v1/images/generations")
    async def generate_image(request: Request):
        request_count["total"] += 1
        request_count["generate"] += 1

        # Optional artificial latency to exercise routing under load.
        if latency > 0:
            await asyncio.sleep(latency)

        # Optional simulated failure, applied before reading the body.
        if fail_rate > 0 and random.random() < fail_rate:
            return JSONResponse(
                status_code=500,
                content={"detail": "Simulated worker failure"},
            )

        body = await request.json()
        response_format = body.get("response_format", "url")
        prompt = body.get("prompt", "")
        # Accept both OpenAI-style "n" and sglang-style "num_images".
        n = body.get("n", body.get("num_images", 1))

        data = []
        for i in range(n):
            if response_format == "b64_json":
                data.append(
                    {
                        "b64_json": _TINY_PNG_B64,
                        "revised_prompt": prompt,
                        "index": i,
                    }
                )
            else:
                # URL form encodes the request counter so each URL is unique.
                data.append(
                    {
                        "url": f"http://localhost/files/img_{request_count['generate']:04d}_{i}.png",
                        "revised_prompt": prompt,
                        "index": i,
                    }
                )

        return {
            "created": int(time.time()),
            "data": data,
            "model": body.get("model", "fake-model"),
            "worker_id": worker_id,
        }

    @app.post("/v1/videos")
    async def generate_video(request: Request):
        # NOTE(review): unlike image generation, this route does not honor
        # fail_rate — confirm whether simulated video failures are needed.
        request_count["total"] += 1
        request_count["video"] += 1

        if latency > 0:
            await asyncio.sleep(latency)

        body = await request.json()
        prompt = body.get("prompt", "")

        return {
            "created": int(time.time()),
            "data": [
                {
                    "url": f"http://localhost/files/vid_{request_count['video']:04d}.mp4",
                    "revised_prompt": prompt,
                }
            ],
            "model": body.get("model", "fake-model"),
            "worker_id": worker_id,
        }

    @app.post("/update_weights_from_disk")
    async def update_weights(request: Request):
        # Echo back the requested model_path; no real weight loading happens.
        request_count["total"] += 1
        request_count["weights"] += 1
        body = await request.json()
        return {
            "ok": True,
            "model_path": body.get("model_path", ""),
            "worker_id": worker_id,
        }

    @app.get("/stats")
    async def stats():
        """Test helper: return request counts."""
        return request_count

    return app
|
|
||
|
|
||
def main():
    """CLI entry point: parse options and serve the fake worker forever."""
    parser = argparse.ArgumentParser(description="Fake sglang diffusion worker")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, required=True)
    parser.add_argument("--fail-rate", type=float, default=0.0)
    parser.add_argument("--latency", type=float, default=0.0)
    parser.add_argument("--worker-id", type=str, default=None)
    parser.add_argument("--task-type", type=str, default="T2V")
    args = parser.parse_args()

    # Derive a stable per-port worker id unless one was given explicitly.
    app = create_app(
        fail_rate=args.fail_rate,
        latency=args.latency,
        worker_id=args.worker_id or f"fake-worker-{args.port}",
        task_type=args.task_type,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="warning")
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
There was a problem hiding this comment.
The fake_worker.py script is an excellent addition for unit and fake E2E testing. It provides a lightweight, mockable HTTP API that mimics a real SGLang diffusion worker, allowing for comprehensive testing without GPU dependencies. The configurable fail_rate and latency are particularly useful for simulating various worker behaviors.
| created: list[DiffusionRouter] = [] | ||
|
|
||
| def _create(workers=None, dead=None, **kw): | ||
| r = DiffusionRouter(_make_args(**kw)) | ||
| if workers is not None: | ||
| r.worker_request_counts = dict(workers) | ||
| r.worker_failure_counts = {u: 0 for u in workers} | ||
| if dead: | ||
| router.dead_workers = set(dead) | ||
| created_routers.append(router) | ||
| return router | ||
| r.dead_workers = set(dead) | ||
| created.append(r) | ||
| return r | ||
|
|
||
| yield _create | ||
| for r in created: | ||
| asyncio.run(r.client.aclose()) |
There was a problem hiding this comment.
The refactoring of the router_factory fixture to use a created list and asyncio.run(r.client.aclose()) in the teardown ensures that all httpx.AsyncClient instances are properly closed after each test, preventing resource leaks and potential issues with open connections. This is a significant improvement for test reliability.
class TestRegisterWorker:
    """Worker registration: dedup, SSRF blocking, and routing recovery."""

    def test_registers_and_deduplicates(self, router_factory):
        router = router_factory()
        for url in ("http://w1:8000", "http://w1:8000", "http://w2:9000"):
            router.register_worker(url)
        assert len(router.worker_request_counts) == 2

    def test_registered_worker_is_routable(self, router_factory):
        router = router_factory()
        router.register_worker("http://w1:8000")
        assert router._select_worker_by_routing() == "http://w1:8000"

    def test_rejects_blocked_host(self, router_factory):
        router = router_factory()
        with pytest.raises(ValueError, match="blocked"):
            router.register_worker("http://169.254.169.254:80")
        assert len(router.worker_request_counts) == 0

    def test_dead_then_register_new_restores_routing(self, router_factory):
        router = router_factory(
            {"http://w1:8000": 0},
            dead={"http://w1:8000"},
        )
        with pytest.raises(RuntimeError):
            router._select_worker_by_routing()
        router.register_worker("http://w2:9000")
        assert router._select_worker_by_routing() == "http://w2:9000"
There was a problem hiding this comment.
The new TestRegisterWorker class with its dedicated tests for registration, deduplication, and blocked hosts significantly improves the coverage and robustness of the worker registration logic. The test for dead_then_register_new_restores_routing is particularly important for ensuring correct behavior when workers recover or new ones are added after failures.
class TestBuildProxyResponse:
    """_build_proxy_response: JSON passthrough, size limits, status codes.

    Fix: this region of the page interleaved diff hunks from several test
    classes (stray asserts on undefined `response`/`router`, plus fragments
    of register-worker and broadcast tests). Reconstructed into coherent
    classes below; the broadcast test's invocation was truncated from view
    and is left as a TODO rather than guessed at.
    """

    def test_small_json(self, router_factory):
        r = router_factory()
        content = json.dumps({"key": "value"}).encode()
        resp = r._build_proxy_response(
            content, 200, {"content-type": "application/json"}
        )
        assert json.loads(resp.body) == {"key": "value"}
        assert resp.status_code == 200

    def test_large_json_returns_raw(self, router_factory):
        r = router_factory()
        big = json.dumps({"data": "x" * (300 * 1024)}).encode()
        resp = r._build_proxy_response(big, 200, {"content-type": "application/json"})
        assert resp.body == big

    def test_preserves_status_code(self, router_factory):
        r = router_factory()
        content = json.dumps({"error": "not found"}).encode()
        resp = r._build_proxy_response(
            content, 404, {"content-type": "application/json"}
        )
        assert resp.status_code == 404


class TestRegisterWorkerNormalization:
    """URL normalization and SSRF host blocking during registration."""

    def test_register_worker_normalizes_duplicate_urls(self, router_factory):
        router = router_factory({})
        router.register_worker("http://LOCALHOST:10090/")
        router.register_worker("http://localhost:10090")
        assert list(router.worker_request_counts.keys()) == ["http://localhost:10090"]

    def test_register_worker_rejects_metadata_host(self, router_factory):
        router = router_factory({})
        with pytest.raises(ValueError, match="host is blocked"):
            router.register_worker("http://169.254.169.254:80")

    # TODO(review): restore test_broadcast_to_workers_collects_per_worker_results;
    # its FakeResponse/fake_post setup and final assertions were visible but the
    # broadcast invocation itself was cut from the diff view.
class TestSanitizeResponseHeaders:
    """Hop-by-hop and encoding headers must be stripped before proxying."""

    def test_removes_hop_by_hop_and_encoding(self):
        raw = {
            "content-type": "application/json",
            "connection": "keep-alive",
            "transfer-encoding": "chunked",
            "content-length": "1234",
            "content-encoding": "gzip",
            "x-custom": "value",
        }
        cleaned = DiffusionRouter._sanitize_response_headers(raw)
        assert set(cleaned.keys()) == {"content-type", "x-custom"}
|
|
||
|
|
||
class TestTryDecodeJson:
    """_try_decode_json: valid bytes decode; invalid bytes fall back to raw.

    Fix: the original had stray `async def fake_post` diff lines from a
    different test interleaved between the two methods; removed here.
    """

    def test_valid_json(self):
        assert DiffusionRouter._try_decode_json(b'{"a": 1}') == {"a": 1}

    def test_invalid_json_returns_raw(self):
        result = DiffusionRouter._try_decode_json(b"not json")
        assert "raw" in result
class TestConstructor:
    """DiffusionRouter construction: defaults, initial state, route table.

    Fix: the original interleaved stray assertions from a broadcast test
    (`result`, `worker_url`) inside test_default_algorithm; removed here.
    Every test closes the router's async client so no connections leak.
    """

    def test_default_algorithm(self):
        # Namespace deliberately omits routing_algorithm to exercise the default.
        args = Namespace(
            host="127.0.0.1", port=30080, max_connections=100, timeout=120.0
        )
        r = DiffusionRouter(args)
        try:
            assert r.routing_algorithm == "least-request"
        finally:
            asyncio.run(r.client.aclose())

    def test_none_timeout_defaults_to_120(self):
        args = Namespace(
            host="127.0.0.1",
            port=30080,
            max_connections=100,
            timeout=None,
            routing_algorithm="least-request",
        )
        r = DiffusionRouter(args)
        try:
            assert r.client.timeout.connect == 120.0
        finally:
            asyncio.run(r.client.aclose())

    def test_initial_state_empty(self):
        args = _make_args()
        r = DiffusionRouter(args)
        try:
            assert r.worker_request_counts == {}
            assert r.dead_workers == set()
        finally:
            asyncio.run(r.client.aclose())

    def test_all_routes_registered(self):
        args = _make_args()
        r = DiffusionRouter(args)
        try:
            routes = [route.path for route in r.app.routes]
            for path in [
                "/add_worker",
                "/list_workers",
                "/health",
                "/health_workers",
                "/generate",
                "/generate_video",
                "/update_weights_from_disk",
                "/{path:path}",
            ]:
                assert path in routes
        finally:
            asyncio.run(r.client.aclose())
There was a problem hiding this comment.
Fixes #5, #26, and #19