Skip to content

Commit 634aba8

Browse files
Merge pull request #33 from alphabetc1/ci/add_more_testcase
fix: prevent /generate 502 caused by event loop mismatch + add e2e tests
2 parents de7cc3e + 2cd7a92 commit 634aba8

File tree

17 files changed

+1491
-44
lines changed

17 files changed

+1491
-44
lines changed

.github/workflows/cpu-test-api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ jobs:
2424
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
2525
2626
- name: Run CPU unit tests
27-
run: pytest tests/unit -v
27+
run: pytest tests/unit tests/integration -v

development.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@ Run CPU only tests:
1010

1111
```bash
1212
pip install pytest
13-
pytest tests/unit -v
13+
# CPU-only tests (unit + integration)
14+
pytest tests/unit tests/integration -v
15+
16+
# Real E2E tests (GPU required, longer runtime)
17+
pytest tests/e2e/test_e2e_sglang.py -v -s
1418
```
1519

1620
## Benchmark Scripts
1721

1822
Benchmark scripts are available under `tests/benchmarks/diffusion_router/` and are intended for manual runs.
19-
They are not part of default unit test collection (`pytest tests/unit -v`).
23+
They are not part of the default test collection (`pytest tests/unit tests/integration -v`).
2024

2125
Single benchmark:
2226

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,9 @@ package-dir = { "" = "src" }
4040
where = ["src"]
4141

4242
[tool.pytest.ini_options]
43-
testpaths = ["tests/unit"]
43+
testpaths = ["tests/unit", "tests/integration"]
44+
markers = [
45+
"integration: CPU-only integration tests with real processes.",
46+
"real_e2e: Real e2e tests requiring sglang and GPU.",
47+
]
4448
pythonpath = ["src"]

src/sglang_diffusion_routing/cli/main.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,9 @@ def _run_router_server(
6767
) from exc
6868

6969
worker_urls = list(args.worker_urls or [])
70-
refresh_tasks = []
7170
for url in worker_urls:
7271
normalized_url = router.normalize_worker_url(url)
7372
router.register_worker(normalized_url)
74-
refresh_tasks.append(router.refresh_worker_video_support(normalized_url))
75-
76-
if refresh_tasks:
77-
78-
async def _refresh_all_worker_video_support() -> None:
79-
await asyncio.gather(*refresh_tasks)
80-
81-
asyncio.run(_refresh_all_worker_video_support())
8273

8374
print(f"{log_prefix} starting router on {args.host}:{args.port}", flush=True)
8475
print(

src/sglang_diffusion_routing/launcher/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,20 @@
1010
from collections.abc import Iterable
1111

1212
import httpx
13-
import torch
1413

1514
# TODO (mengyang, shuwen, chenyang): these utils should be cleaned up.
1615

1716

17+
def _cuda_device_count() -> int:
18+
"""Return a best-effort CUDA device count without importing torch at module import time."""
19+
try:
20+
import torch
21+
22+
return int(torch.cuda.device_count())
23+
except Exception:
24+
return 0
25+
26+
1827
def infer_connect_host(host: str) -> str:
1928
"""Normalize bind-all addresses to loopback for client connections."""
2029
if host in ("0.0.0.0", "::", "localhost"):
@@ -72,7 +81,7 @@ def resolve_gpu_pool(
7281
if parsed:
7382
return parsed
7483

75-
gpu_count = int(torch.cuda.device_count())
84+
gpu_count = _cuda_device_count()
7685
if gpu_count > 0:
7786
return [str(i) for i in range(gpu_count)]
7887
return None
@@ -116,7 +125,7 @@ def build_gpu_assignments(
116125
gpu_pool = parsed
117126

118127
if gpu_pool is None:
119-
gpu_count = int(torch.cuda.device_count())
128+
gpu_count = _cuda_device_count()
120129
if gpu_count > 0:
121130
gpu_pool = [str(i) for i in range(gpu_count)]
122131

src/sglang_diffusion_routing/router/diffusion_router.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@ def _setup_routes(self) -> None:
8686
)
8787

8888
async def _start_background_health_check(self) -> None:
89+
# Probe capability for pre-registered workers in the active server loop.
90+
unknown_workers = [
91+
url for url, support in self.worker_video_support.items() if support is None
92+
]
93+
if unknown_workers:
94+
await asyncio.gather(
95+
*(self.refresh_worker_video_support(url) for url in unknown_workers),
96+
return_exceptions=True,
97+
)
98+
8999
if self._health_task is None or self._health_task.done():
90100
self._health_task = asyncio.create_task(self._health_check_loop())
91101

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Pytest configuration: force local src import precedence."""
2+
3+
from __future__ import annotations
4+
5+
import sys
6+
from pathlib import Path
7+
8+
src_str = str(Path(__file__).resolve().parent.parent / "src")
9+
while src_str in sys.path:
10+
sys.path.remove(src_str)
11+
sys.path.insert(0, src_str)

tests/e2e/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)