
[WIP] ci: add more testcase #25

Merged
zhaochenyang20 merged 3 commits into zhaochenyang20:main from
alphabetc1:ci/add_more_testcase
Feb 23, 2026

Conversation

@alphabetc1
Collaborator

@alphabetc1 alphabetc1 commented Feb 23, 2026

Fixes #5, #26, and #19.

@gemini-code-assist
Contributor

Summary of Changes

Hello @alphabetc1, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly strengthens the testing framework for the sglang-diffusion-routing project. It introduces robust end-to-end tests that validate the router's behavior with both real and simulated SGLang workers, ensuring reliable operation of load balancing, request routing, and worker management. The changes also include minor refactorings to improve the router's startup logic and error handling, making the system more resilient and easier to maintain.

Highlights

  • Enhanced Testing Infrastructure: Introduced comprehensive end-to-end (E2E) tests for the sglang-diffusion-routing, covering real SGLang workers (GPU required) and a new CPU-only fake E2E testing setup.
  • Refactored Worker Video Support: Moved the asynchronous probing of worker video capabilities from the CLI's router server startup to the router's background health check, improving startup efficiency.
  • Improved Router Robustness: Added a check in the generate_video endpoint to return a 503 error if no workers are registered, preventing potential issues.
  • Updated Pytest Configuration: Modified pyproject.toml to include the new E2E test paths and added a conftest.py to manage local src import precedence for testing.
  • Refined Unit Tests: Updated existing unit tests for CLI argument parsing and DiffusionRouter core functionality, making them more focused and adding new tests for worker registration, response building, and router initialization.
Changelog
  • development.md
    • Updated test instructions to include separate commands for CPU-only unit tests and GPU-required real E2E tests.
  • pyproject.toml
    • Configured pytest to discover tests in both 'tests/unit' and 'tests/e2e' directories.
  • src/sglang_diffusion_routing/cli/main.py
    • Removed the direct asyncio import and the immediate worker video support refresh logic from _run_router_server, delegating this to the router's background tasks.
  • src/sglang_diffusion_routing/router/diffusion_router.py
    • Moved the probing of worker video capabilities to the _start_background_health_check method, allowing it to run asynchronously within the event loop.
    • Added a check in generate_video to return a 503 status code if no workers are registered in the pool.
  • tests/conftest.py
    • Added a pytest configuration file to ensure the local 'src' directory is prioritized in the Python path for test execution.
  • tests/e2e/test_e2e_sglang.py
    • Added a new file containing end-to-end tests for the sglang diffusion router, covering health, image/video generation, load balancing, and proxying with real SGLang workers.
  • tests/unit/fake_worker.py
    • Added a new file implementing a fake SGLang diffusion worker, designed for CPU-only end-to-end testing scenarios.
  • tests/unit/test_cli.py
    • Updated the module docstring to reflect that CLI integration is now covered by e2e tests.
    • Renamed test_parses_worker_urls to test_full_args and expanded it to validate parsing of new CLI arguments such as timeout, max-connections, health-check-interval, and health-check-failure-threshold.
    • Added tests to ensure invalid routing algorithms are rejected and all valid algorithms are accepted.
    • Removed test_run_cli_calls_router_runner as its functionality is now covered by e2e tests.
  • tests/unit/test_diffusion_router.py
    • Updated the module docstring to clarify that tests directly call real code without mocks.
    • Renamed make_router_args to _make_args for internal use.
    • Refactored the router_factory fixture to simplify router creation and ensure proper client closure.
    • Added test_picks_min_load to TestLeastRequest to verify selection of the worker with minimum load.
    • Renamed test_cycles_workers to test_cycles_in_order and test_excludes_dead_workers to test_skips_dead_workers in TestRoundRobin.
    • Renamed TestRandom to TestRandomRouting and test_selects_from_valid_workers to test_covers_all_workers.
    • Renamed TestErrorCases to TestRoutingEdgeCases and its tests to test_no_workers_raises and test_all_dead_raises.
    • Removed TestDefaults and TestRegressions classes, integrating relevant tests into TestConstructor or removing redundant ones.
    • Added new test classes: TestRegisterWorker for worker registration, TestBuildProxyResponse for response building, TestSanitizeResponseHeaders for header cleaning, TestTryDecodeJson for JSON decoding, and TestConstructor for router initialization and route registration.
  • tests/unit/test_router_endpoints.py
    • Removed this file as its test cases are now covered by the new fake end-to-end tests in tests/unit/test_fake_e2e.py.
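The conftest.py entry above says the local src directory is prioritized on the Python path. The file's actual contents are not shown in this PR view, so here is an illustrative sketch of that shim, assuming the standard repo layout of <root>/src and <root>/tests (the helper name prioritize_src and the example path are hypothetical):

```python
# Sketch of a sys.path shim a tests/conftest.py could apply so the local
# checkout shadows any installed copy of the package. Illustrative only.
import sys
from pathlib import Path


def prioritize_src(conftest_path: str) -> str:
    """Put <repo>/src at the front of sys.path, given the conftest.py path."""
    src = str(Path(conftest_path).resolve().parent.parent / "src")
    if src in sys.path:
        sys.path.remove(src)
    sys.path.insert(0, src)  # front of the path wins at import time
    return src


# Hypothetical repo location, for demonstration.
src = prioritize_src("/tmp/repo/tests/conftest.py")
```

In a real conftest.py the call would use `__file__` instead of a hard-coded path.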
Ignored Files
  • Ignored by pattern: .github/workflows/** (1)
    • .github/workflows/pr-test.yml
Activity
  • The pull request is currently marked as 'Work In Progress' (WIP), indicating that the author is still actively developing or refining the changes.
  • The pull request description links to a GitHub issue ([Question] Tests is a little bit fake #5), suggesting that these changes are part of a larger effort to address a specific problem or feature request.
  • The extensive changes across multiple test files and the addition of new test types (E2E, fake E2E) show a significant focus on improving the project's testing coverage and reliability.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command              | Description
Code Review          | /gemini review       | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary      | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist  | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help         | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

The pull request introduces new test cases for both unit and end-to-end scenarios, significantly improving test coverage. It also refactors the worker video support probing logic to occur within the background health check, rather than during router startup, which is a good improvement for robustness. Additionally, a check for no registered workers has been added to the generate_video endpoint. The changes are well-structured and enhance the overall stability and testability of the application.

Comment on lines +396 to +400
if not self.worker_request_counts:
    return JSONResponse(
        status_code=503,
        content={"error": "No workers registered in the pool"},
    )

high

Adding a check for worker_request_counts before attempting to route video generation requests is crucial. This prevents errors when no workers are registered and provides a clear 503 Service Unavailable response to the client.
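The decision this guard protects can be sketched as a pure function. This is illustrative only: route_or_503 and the payload shape are stand-ins for the router's actual generate_video logic, which also performs the HTTP proxying.

```python
def route_or_503(worker_request_counts: dict[str, int]) -> tuple[int, dict]:
    # No registered workers: fail fast with 503 instead of raising downstream.
    if not worker_request_counts:
        return 503, {"error": "No workers registered in the pool"}
    # Otherwise pick the least-loaded worker (least-request routing).
    url = min(worker_request_counts, key=worker_request_counts.get)
    return 200, {"routed_to": url}
```

For example, an empty pool yields a 503 payload, while a populated pool routes to the worker with the fewest in-flight requests.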

Comment on lines 28 to 30
for url in worker_urls:
    normalized_url = router.normalize_worker_url(url)
    router.register_worker(normalized_url)

medium

The removal of the refresh_tasks list and the _refresh_all_worker_video_support async function simplifies the router startup logic in _run_router_server. This logic has been correctly moved to the _start_background_health_check method in diffusion_router.py.

Comment on lines +72 to +80
# Probe video capability for pre-registered workers in the running event loop.
unknown_workers = [
    url for url, support in self.worker_video_support.items() if support is None
]
if unknown_workers:
    await asyncio.gather(
        *(self.refresh_worker_video_support(url) for url in unknown_workers),
        return_exceptions=True,
    )

medium

Moving the video capability probing for pre-registered workers into _start_background_health_check is a good architectural improvement. This ensures that the probing happens asynchronously within the running event loop, preventing blocking during router startup and making the process more robust.
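The return_exceptions=True flag in the snippet above is what keeps one unreachable worker from aborting the whole probe pass. A minimal standalone sketch of that pattern, where probe is a hypothetical stand-in for refresh_worker_video_support:

```python
import asyncio


async def probe(url: str, fail: bool = False):
    # Hypothetical stand-in for refresh_worker_video_support.
    await asyncio.sleep(0)
    if fail:
        raise ConnectionError(url)
    return url, True


async def probe_all():
    # With return_exceptions=True, a failing probe is returned as an
    # exception object in the results list instead of being raised,
    # so the other probes still complete.
    return await asyncio.gather(
        probe("http://w1:8000"),
        probe("http://w2:8000", fail=True),
        return_exceptions=True,
    )


results = asyncio.run(probe_all())
```

Without the flag, the first ConnectionError would propagate and the surviving workers' results would be lost.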

Comment on lines +1 to +399
"""
End-to-end tests with real sglang diffusion workers.

Requires:
- sglang installed with diffusion support: pip install "sglang[diffusion]"
- At least 1 GPU available
- Model weights accessible (downloads on first run)

These tests are SKIPPED automatically when sglang or GPU is not available.
To run explicitly:

pytest tests/e2e/test_e2e_sglang.py -v -s

Override model/GPU config via environment variables:
SGLANG_TEST_MODEL Model path (default: Qwen/Qwen-Image)
SGLANG_TEST_NUM_GPUS GPUs per worker (default: 1)
SGLANG_TEST_NUM_WORKERS Number of workers (default: 2)
SGLANG_TEST_TIMEOUT Startup timeout in seconds (default: 600)
"""

from __future__ import annotations

import base64
import os
import shutil
import signal
import socket
import subprocess
import sys
import time
from pathlib import Path

import httpx
import pytest

REPO_ROOT = Path(__file__).resolve().parents[2]
PYTHON = sys.executable

DEFAULT_MODEL = "Qwen/Qwen-Image"
DEFAULT_NUM_GPUS = 1
DEFAULT_NUM_WORKERS = 2
DEFAULT_TIMEOUT = 600 # sglang model loading can be slow


def _has_sglang() -> bool:
try:
import sglang # noqa: F401

return True
except ImportError:
return False


def _gpu_count() -> int:
try:
import torch

return torch.cuda.device_count()
except Exception:
return 0


def _get_env(key: str, default: str) -> str:
return os.environ.get(key, default)


def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


def _env() -> dict[str, str]:
env = os.environ.copy()
src = str(REPO_ROOT / "src")
old = env.get("PYTHONPATH", "")
env["PYTHONPATH"] = f"{src}:{old}" if old else src
return env


def _wait_healthy(
url: str, timeout: float, label: str = "", proc: subprocess.Popen | None = None
) -> None:
deadline = time.monotonic() + timeout
last_log = 0.0
while time.monotonic() < deadline:
if proc is not None and proc.poll() is not None:
raise RuntimeError(
f"{label} process exited with code {proc.returncode} during startup"
)
try:
r = httpx.get(f"{url}/health", timeout=5.0)
if r.status_code == 200:
return
except httpx.HTTPError:
pass
now = time.monotonic()
if now - last_log >= 30:
elapsed = now - (deadline - timeout)
print(f" Waiting for {label}... ({elapsed:.0f}s)", flush=True)
last_log = now
time.sleep(2)
raise TimeoutError(f"{label} at {url} not healthy within {timeout}s")


def _kill_proc(proc: subprocess.Popen) -> None:
if proc.poll() is not None:
return
try:
os.killpg(proc.pid, signal.SIGTERM)
except (ProcessLookupError, PermissionError):
pass
try:
proc.wait(timeout=15)
except subprocess.TimeoutExpired:
try:
os.killpg(proc.pid, signal.SIGKILL)
except (ProcessLookupError, PermissionError):
pass
proc.wait(timeout=5)


# -- Skip conditions -------------------------------------------------------

_skip_no_sglang = pytest.mark.skipif(
not _has_sglang() or shutil.which("sglang") is None,
reason="sglang not installed or 'sglang' CLI not in PATH",
)
_skip_no_gpu = pytest.mark.skipif(
_gpu_count() == 0,
reason="No GPU available",
)

pytestmark = [_skip_no_sglang, _skip_no_gpu]


# -- Fixtures ---------------------------------------------------------------


class SglangWorker:
def __init__(self, proc: subprocess.Popen, url: str):
self.proc = proc
self.url = url


@pytest.fixture(scope="module")
def sglang_config():
model = _get_env("SGLANG_TEST_MODEL", DEFAULT_MODEL)
num_gpus = int(_get_env("SGLANG_TEST_NUM_GPUS", str(DEFAULT_NUM_GPUS)))
num_workers = int(_get_env("SGLANG_TEST_NUM_WORKERS", str(DEFAULT_NUM_WORKERS)))
timeout = int(_get_env("SGLANG_TEST_TIMEOUT", str(DEFAULT_TIMEOUT)))

gpus_available = _gpu_count()
needed = num_workers * num_gpus
if gpus_available < needed:
pytest.skip(
f"Need {needed} GPUs ({num_workers} workers x {num_gpus} GPUs), "
f"only {gpus_available} available"
)

return {
"model": model,
"num_gpus": num_gpus,
"num_workers": num_workers,
"timeout": timeout,
}


@pytest.fixture(scope="module")
def sglang_workers(sglang_config):
"""Launch real sglang diffusion worker processes."""
workers = []
procs = []
env = _env()
gpu_pool = list(range(_gpu_count()))

for i in range(sglang_config["num_workers"]):
port = _find_free_port()
gpu_start = i * sglang_config["num_gpus"]
gpu_end = gpu_start + sglang_config["num_gpus"]
gpu_ids = ",".join(str(gpu_pool[g]) for g in range(gpu_start, gpu_end))

worker_env = dict(env)
worker_env["CUDA_VISIBLE_DEVICES"] = gpu_ids

cmd = [
"sglang",
"serve",
"--model-path",
sglang_config["model"],
"--num-gpus",
str(sglang_config["num_gpus"]),
"--host",
"127.0.0.1",
"--port",
str(port),
]

print(
f"\n[sglang-test] Starting worker {i} on port {port} (GPU: {gpu_ids})",
flush=True,
)
proc = subprocess.Popen(
cmd,
env=worker_env,
start_new_session=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
procs.append(proc)
workers.append(SglangWorker(proc, f"http://127.0.0.1:{port}"))

try:
for w in workers:
_wait_healthy(
w.url,
sglang_config["timeout"],
label=f"sglang worker {w.url}",
proc=w.proc,
)
except (RuntimeError, TimeoutError) as exc:
# Collect stderr from the first failed worker for diagnostics
stderr_snippet = ""
for p in procs:
if p.poll() is not None and p.stderr:
try:
stderr_snippet = p.stderr.read(2048).decode(
"utf-8", errors="replace"
)
except Exception:
pass
if stderr_snippet:
break
for p in procs:
_kill_proc(p)
pytest.skip(
f"sglang worker failed to start: {exc}"
+ (f"\nstderr: {stderr_snippet[:500]}" if stderr_snippet else "")
)

yield workers

for p in procs:
_kill_proc(p)


@pytest.fixture(scope="module")
def router_url(sglang_workers):
"""Launch a real router connected to real sglang workers."""
port = _find_free_port()
worker_urls = [w.url for w in sglang_workers]
cmd = [
PYTHON,
"-m",
"sglang_diffusion_routing",
"--host",
"127.0.0.1",
"--port",
str(port),
"--worker-urls",
*worker_urls,
"--routing-algorithm",
"least-request",
"--log-level",
"warning",
]
proc = subprocess.Popen(
cmd,
env=_env(),
start_new_session=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
url = f"http://127.0.0.1:{port}"
try:
_wait_healthy(url, 30, label="router", proc=proc)
except Exception:
_kill_proc(proc)
raise

yield url
_kill_proc(proc)


# -- Tests ------------------------------------------------------------------


class TestSglangHealth:
def test_router_healthy(self, router_url):
r = httpx.get(f"{router_url}/health", timeout=10.0)
assert r.status_code == 200
assert r.json()["status"] == "healthy"

def test_workers_listed(self, router_url, sglang_workers):
urls = httpx.get(f"{router_url}/list_workers", timeout=10.0).json()["urls"]
assert len(urls) == len(sglang_workers)


class TestSglangImageGeneration:
def test_b64_json_generates_real_image(self, router_url, sglang_config):
"""Generate a real image through sglang and verify it's a valid PNG."""
r = httpx.post(
f"{router_url}/generate",
json={
"model": sglang_config["model"],
"prompt": "a simple red circle on white background",
"num_images": 1,
"response_format": "b64_json",
},
timeout=120.0,
)
assert r.status_code == 200
body = r.json()
assert "data" in body
assert len(body["data"]) == 1

img_bytes = base64.b64decode(body["data"][0]["b64_json"])
# Verify PNG magic bytes
assert img_bytes[:8] == b"\x89PNG\r\n\x1a\n"
# Real image should be substantially larger than a 1x1 pixel
assert len(img_bytes) > 1000

def test_multiple_images(self, router_url, sglang_config):
r = httpx.post(
f"{router_url}/generate",
json={
"model": sglang_config["model"],
"prompt": "a blue square",
"num_images": 2,
"response_format": "b64_json",
},
timeout=120.0,
)
assert r.status_code == 200
assert len(r.json()["data"]) == 2

def test_url_format(self, router_url, sglang_config):
"""Generate with response_format=url (requires worker file storage)."""
r = httpx.post(
f"{router_url}/generate",
json={
"model": sglang_config["model"],
"prompt": "a green triangle",
"num_images": 1,
"response_format": "url",
},
timeout=120.0,
)
# url format may require cloud storage config — accept either success or
# a clear error about storage, not a crash
assert r.status_code in (200, 400, 500)
if r.status_code == 200:
assert "url" in r.json()["data"][0]


class TestSglangLoadBalancing:
def test_requests_distributed(self, router_url, sglang_workers, sglang_config):
"""With multiple workers, requests should be distributed."""
if len(sglang_workers) < 2:
pytest.skip("Need at least 2 workers for load balancing test")

for _ in range(4):
r = httpx.post(
f"{router_url}/generate",
json={
"model": sglang_config["model"],
"prompt": "test",
"num_images": 1,
"response_format": "b64_json",
},
timeout=120.0,
)
assert r.status_code == 200

# Verify health shows all workers still alive
health = httpx.get(f"{router_url}/health_workers", timeout=10.0).json()
assert all(not w["is_dead"] for w in health["workers"])


class TestSglangProxy:
def test_get_model_info(self, router_url):
"""Proxy should forward GET requests to worker."""
r = httpx.get(f"{router_url}/get_model_info", timeout=30.0)
# sglang workers expose /get_model_info
assert r.status_code in (200, 404)


class TestSglangVideoEndpoint:
def test_generate_video_rejects_image_only_workers(self, router_url):
"""Image-only workers (e.g. Qwen/Qwen-Image) should return 400 for /generate_video."""
r = httpx.post(
f"{router_url}/generate_video",
json={"prompt": "a walking cat", "num_frames": 8},
timeout=10.0,
)
assert r.status_code == 400
body = r.json()
assert "error" in body
assert body["error"] # non-empty error message

medium

The addition of test_e2e_sglang.py is a significant improvement to the test suite. These end-to-end tests with real SGLang diffusion workers provide high confidence in the system's integration and functionality, especially given the GPU requirements and model loading complexities. The use of fixtures for worker and router setup is well-implemented.

Comment thread: tests/unit/fake_worker.py
Comment on lines +1 to +167
#!/usr/bin/env python3
"""
Fake sglang diffusion worker for e2e testing.

Implements the same HTTP API contract as a real sglang diffusion worker,
but returns canned responses without any GPU or model dependencies.

Usage:
    python tests/unit/fake_worker.py --port 19000
    python tests/unit/fake_worker.py --port 19000 --fail-rate 0.5
    python tests/unit/fake_worker.py --port 19000 --latency 0.2
"""

from __future__ import annotations

import argparse
import asyncio
import base64
import random
import time

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

# 1x1 red PNG pixel
_TINY_PNG = base64.b64decode(
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4"
    "2mP8z8BQDwADhQGAWjR9awAAAABJRU5ErkJggg=="
)
_TINY_PNG_B64 = base64.b64encode(_TINY_PNG).decode()


def create_app(
    fail_rate: float = 0.0,
    latency: float = 0.0,
    worker_id: str = "fake-worker",
    task_type: str = "T2V",
) -> FastAPI:
    app = FastAPI()
    request_count = {"total": 0, "generate": 0, "video": 0, "weights": 0}

    @app.get("/health")
    async def health():
        return {"status": "ok", "worker_id": worker_id}

    @app.get("/v1/models")
    async def list_models():
        return {
            "object": "list",
            "data": [
                {
                    "id": "fake-model",
                    "task_type": task_type,
                }
            ],
        }

    @app.post("/v1/images/generations")
    async def generate_image(request: Request):
        request_count["total"] += 1
        request_count["generate"] += 1

        if latency > 0:
            await asyncio.sleep(latency)

        if fail_rate > 0 and random.random() < fail_rate:
            return JSONResponse(
                status_code=500,
                content={"detail": "Simulated worker failure"},
            )

        body = await request.json()
        response_format = body.get("response_format", "url")
        prompt = body.get("prompt", "")
        n = body.get("n", body.get("num_images", 1))

        data = []
        for i in range(n):
            if response_format == "b64_json":
                data.append(
                    {
                        "b64_json": _TINY_PNG_B64,
                        "revised_prompt": prompt,
                        "index": i,
                    }
                )
            else:
                data.append(
                    {
                        "url": f"http://localhost/files/img_{request_count['generate']:04d}_{i}.png",
                        "revised_prompt": prompt,
                        "index": i,
                    }
                )

        return {
            "created": int(time.time()),
            "data": data,
            "model": body.get("model", "fake-model"),
            "worker_id": worker_id,
        }

    @app.post("/v1/videos")
    async def generate_video(request: Request):
        request_count["total"] += 1
        request_count["video"] += 1

        if latency > 0:
            await asyncio.sleep(latency)

        body = await request.json()
        prompt = body.get("prompt", "")

        return {
            "created": int(time.time()),
            "data": [
                {
                    "url": f"http://localhost/files/vid_{request_count['video']:04d}.mp4",
                    "revised_prompt": prompt,
                }
            ],
            "model": body.get("model", "fake-model"),
            "worker_id": worker_id,
        }

    @app.post("/update_weights_from_disk")
    async def update_weights(request: Request):
        request_count["total"] += 1
        request_count["weights"] += 1
        body = await request.json()
        return {
            "ok": True,
            "model_path": body.get("model_path", ""),
            "worker_id": worker_id,
        }

    @app.get("/stats")
    async def stats():
        """Test helper: return request counts."""
        return request_count

    return app


def main():
    parser = argparse.ArgumentParser(description="Fake sglang diffusion worker")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, required=True)
    parser.add_argument("--fail-rate", type=float, default=0.0)
    parser.add_argument("--latency", type=float, default=0.0)
    parser.add_argument("--worker-id", type=str, default=None)
    parser.add_argument("--task-type", type=str, default="T2V")
    args = parser.parse_args()

    worker_id = args.worker_id or f"fake-worker-{args.port}"
    app = create_app(
        fail_rate=args.fail_rate,
        latency=args.latency,
        worker_id=worker_id,
        task_type=args.task_type,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="warning")


if __name__ == "__main__":
    main()

medium

The fake_worker.py script is an excellent addition for unit and fake E2E testing. It provides a lightweight, mockable HTTP API that mimics a real SGLang diffusion worker, allowing for comprehensive testing without GPU dependencies. The configurable fail_rate and latency are particularly useful for simulating various worker behaviors.

Comment on lines +31 to +45
    created: list[DiffusionRouter] = []

    def _create(workers=None, dead=None, **kw):
        r = DiffusionRouter(_make_args(**kw))
        if workers is not None:
            r.worker_request_counts = dict(workers)
            r.worker_failure_counts = {u: 0 for u in workers}
        if dead:
            r.dead_workers = set(dead)
        created.append(r)
        return r

    yield _create
    for r in created:
        asyncio.run(r.client.aclose())

medium

The refactoring of the router_factory fixture to use a created list and asyncio.run(r.client.aclose()) in the teardown ensures that all httpx.AsyncClient instances are properly closed after each test, preventing resource leaks and potential issues with open connections. This is a significant improvement for test reliability.
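The setup/teardown shape the reviewer describes can be shown with a stand-in client. DummyClient and router_factory_sketch are hypothetical illustrations; the real fixture closes httpx.AsyncClient instances on DiffusionRouter objects:

```python
import asyncio


class DummyClient:
    """Stand-in for httpx.AsyncClient exposing only the aclose() we need."""

    def __init__(self):
        self.closed = False

    async def aclose(self):
        self.closed = True


def router_factory_sketch():
    # Generator in the shape of a pytest yield-fixture: the code after
    # `yield` runs at teardown and closes every client created during the test.
    created = []

    def _create():
        client = DummyClient()
        created.append(client)
        return client

    yield _create
    for c in created:
        asyncio.run(c.aclose())


gen = router_factory_sketch()
factory = next(gen)   # setup: obtain the factory
client = factory()    # test body: create a client
next(gen, None)       # teardown: every created client is closed
```

Tracking instances in a list and closing them after the yield is what prevents the resource leaks the reviewer mentions.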

Comment on lines +159 to +186
class TestRegisterWorker:
    def test_registers_and_deduplicates(self, router_factory):
        r = router_factory()
        r.register_worker("http://w1:8000")
        r.register_worker("http://w1:8000")
        r.register_worker("http://w2:9000")
        assert len(r.worker_request_counts) == 2

    def test_registered_worker_is_routable(self, router_factory):
        r = router_factory()
        r.register_worker("http://w1:8000")
        assert r._select_worker_by_routing() == "http://w1:8000"

    def test_rejects_blocked_host(self, router_factory):
        r = router_factory()
        with pytest.raises(ValueError, match="blocked"):
            r.register_worker("http://169.254.169.254:80")
        assert len(r.worker_request_counts) == 0

    def test_dead_then_register_new_restores_routing(self, router_factory):
        r = router_factory(
            {"http://w1:8000": 0},
            dead={"http://w1:8000"},
        )
        with pytest.raises(RuntimeError):
            r._select_worker_by_routing()
        r.register_worker("http://w2:9000")
        assert r._select_worker_by_routing() == "http://w2:9000"

medium

The new TestRegisterWorker class with its dedicated tests for registration, deduplication, and blocked hosts significantly improves the coverage and robustness of the worker registration logic. The test for dead_then_register_new_restores_routing is particularly important for ensuring correct behavior when workers recover or new ones are added after failures.
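The deduplication and host-blocking behavior these tests exercise can be sketched independently. normalize_worker_url here is an illustrative reimplementation, not the project's code; 169.254.169.254 is the link-local cloud metadata address, which a worker pool should never proxy to:

```python
from urllib.parse import urlsplit

BLOCKED_HOSTS = {"169.254.169.254"}  # cloud metadata endpoint, never a worker


def normalize_worker_url(url: str) -> str:
    parts = urlsplit(url)
    host = (parts.hostname or "").lower()
    if host in BLOCKED_HOSTS:
        raise ValueError(f"host is blocked: {host}")
    # Lowercased host + explicit port, trailing slash dropped: a stable dict key.
    if parts.port:
        return f"{parts.scheme}://{host}:{parts.port}"
    return f"{parts.scheme}://{host}"


# Two spellings of the same worker collapse to one registration.
workers: dict[str, int] = {}
for u in ("http://LOCALHOST:10090/", "http://localhost:10090"):
    workers[normalize_worker_url(u)] = 0

try:
    normalize_worker_url("http://169.254.169.254:80")
    metadata_blocked = False
except ValueError:
    metadata_blocked = True
```

Normalizing before using the URL as a dict key is what makes dedup fall out of ordinary dict semantics.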

Comment on lines +192 to +214
class TestBuildProxyResponse:
    def test_small_json(self, router_factory):
        r = router_factory()
        content = json.dumps({"key": "value"}).encode()
        resp = r._build_proxy_response(
            content, 200, {"content-type": "application/json"}
        )
        assert json.loads(resp.body) == {"key": "value"}
        assert resp.status_code == 200

    def test_large_json_returns_raw(self, router_factory):
        r = router_factory()
        big = json.dumps({"data": "x" * (300 * 1024)}).encode()
        resp = r._build_proxy_response(big, 200, {"content-type": "application/json"})
        assert resp.body == big

    def test_preserves_status_code(self, router_factory):
        r = router_factory()
        content = json.dumps({"error": "not found"}).encode()
        resp = r._build_proxy_response(
            content, 404, {"content-type": "application/json"}
        )
        assert resp.status_code == 404

medium

The new TestBuildProxyResponse class with tests for small JSON, large JSON, and status code preservation ensures that the router correctly handles various response types from workers, improving the reliability of the proxy functionality.

Comment on lines +220 to +240
class TestSanitizeResponseHeaders:
    def test_removes_hop_by_hop_and_encoding(self):
        headers = {
            "content-type": "application/json",
            "connection": "keep-alive",
            "transfer-encoding": "chunked",
            "content-length": "1234",
            "content-encoding": "gzip",
            "x-custom": "value",
        }
        result = DiffusionRouter._sanitize_response_headers(headers)
        assert set(result.keys()) == {"content-type", "x-custom"}


class TestTryDecodeJson:
    def test_valid_json(self):
        assert DiffusionRouter._try_decode_json(b'{"a": 1}') == {"a": 1}

    def test_invalid_json_returns_raw(self):
        result = DiffusionRouter._try_decode_json(b"not json")
        assert "raw" in result

medium

The new TestSanitizeResponseHeaders and TestTryDecodeJson classes provide focused unit tests for these static helper methods, ensuring their correctness and improving the overall test coverage of the DiffusionRouter class.
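For reference, a sanitizer with the behavior these tests assert might look like the sketch below. It is illustrative only (the real _sanitize_response_headers may differ in details): hop-by-hop headers from the HTTP/1.1 spec are dropped because they describe a single connection, and content-length/content-encoding are dropped because the proxy may re-encode the body.

```python
# Hop-by-hop headers (per HTTP/1.1) plus fields invalidated by re-encoding.
HOP_BY_HOP = {
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailers", "transfer-encoding", "upgrade",
}
BODY_DEPENDENT = {"content-length", "content-encoding"}


def sanitize_response_headers(headers: dict[str, str]) -> dict[str, str]:
    drop = HOP_BY_HOP | BODY_DEPENDENT
    return {k: v for k, v in headers.items() if k.lower() not in drop}


# Same input as the unit test above.
result = sanitize_response_headers({
    "content-type": "application/json",
    "connection": "keep-alive",
    "transfer-encoding": "chunked",
    "content-length": "1234",
    "content-encoding": "gzip",
    "x-custom": "value",
})
```

Only end-to-end headers (content-type and custom headers) survive the pass.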

Comment on lines +246 to +297
class TestConstructor:
    def test_default_algorithm(self):
        args = Namespace(
            host="127.0.0.1", port=30080, max_connections=100, timeout=120.0
        )
        r = DiffusionRouter(args)
        try:
            assert r.routing_algorithm == "least-request"
        finally:
            asyncio.run(r.client.aclose())

    def test_none_timeout_defaults_to_120(self):
        args = Namespace(
            host="127.0.0.1",
            port=30080,
            max_connections=100,
            timeout=None,
            routing_algorithm="least-request",
        )
        r = DiffusionRouter(args)
        try:
            assert r.client.timeout.connect == 120.0
        finally:
            asyncio.run(r.client.aclose())

    def test_initial_state_empty(self):
        args = _make_args()
        r = DiffusionRouter(args)
        try:
            assert r.worker_request_counts == {}
            assert r.dead_workers == set()
        finally:
            asyncio.run(r.client.aclose())

    def test_all_routes_registered(self):
        args = _make_args()
        r = DiffusionRouter(args)
        try:
            routes = [route.path for route in r.app.routes]
            for path in [
                "/add_worker",
                "/list_workers",
                "/health",
                "/health_workers",
                "/generate",
                "/generate_video",
                "/update_weights_from_disk",
                "/{path:path}",
            ]:
                assert path in routes
        finally:
            asyncio.run(r.client.aclose())

medium

The new TestConstructor class provides comprehensive tests for the DiffusionRouter constructor, covering default algorithm, timeout handling, initial state, and route registration. This ensures that the router is initialized correctly under various conditions.

@zhaochenyang20 zhaochenyang20 merged commit 58a0184 into zhaochenyang20:main Feb 23, 2026
2 checks passed

Development

Successfully merging this pull request may close these issues.

[Question] Tests is a little bit fake
