Skip to content

Commit 2cd7a92

Browse files
committed
reuse utils.py
1 parent 3904852 commit 2cd7a92

File tree

1 file changed

+31
-46
lines changed

1 file changed

+31
-46
lines changed

tests/e2e/test_e2e_sglang.py

Lines changed: 31 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,20 @@
2424
import base64
2525
import os
2626
import shutil
27-
import signal
28-
import socket
2927
import subprocess
3028
import sys
3129
from pathlib import Path
3230

3331
import httpx
3432
import pytest
3533

36-
from sglang_diffusion_routing.launcher.utils import wait_for_health
34+
from sglang_diffusion_routing.launcher.utils import (
35+
build_gpu_assignments,
36+
reserve_available_port,
37+
resolve_gpu_pool,
38+
terminate_all,
39+
wait_for_health,
40+
)
3741

3842
REPO_ROOT = Path(__file__).resolve().parents[2]
3943
PYTHON = sys.executable
@@ -59,22 +63,19 @@ def _has_sglang() -> bool:
5963

6064

6165
def _gpu_count() -> int:
62-
try:
63-
import torch
64-
65-
return torch.cuda.device_count()
66-
except Exception:
67-
return 0
66+
gpu_pool = resolve_gpu_pool(worker_gpu_ids=None, env=os.environ)
67+
return len(gpu_pool) if gpu_pool else 0
6868

6969

7070
def _get_env(key: str, default: str) -> str:
7171
return os.environ.get(key, default)
7272

7373

74-
def _find_free_port() -> int:
75-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
76-
s.bind(("127.0.0.1", 0))
77-
return s.getsockname()[1]
74+
_USED_PORTS: set[int] = set()
75+
76+
77+
def _reserve_local_port(preferred_port: int) -> int:
78+
return reserve_available_port("127.0.0.1", preferred_port, _USED_PORTS)
7879

7980

8081
def _env() -> dict[str, str]:
@@ -97,23 +98,6 @@ def _wait_healthy(
9798
)
9899

99100

100-
def _kill_proc(proc: subprocess.Popen) -> None:
101-
if proc.poll() is not None:
102-
return
103-
try:
104-
os.killpg(proc.pid, signal.SIGTERM)
105-
except (ProcessLookupError, PermissionError):
106-
pass
107-
try:
108-
proc.wait(timeout=15)
109-
except subprocess.TimeoutExpired:
110-
try:
111-
os.killpg(proc.pid, signal.SIGKILL)
112-
except (ProcessLookupError, PermissionError):
113-
pass
114-
proc.wait(timeout=5)
115-
116-
117101
def _read_stderr_snippet(proc: subprocess.Popen, max_bytes: int = 8192) -> str:
118102
if proc.stderr is None:
119103
return ""
@@ -205,13 +189,18 @@ def sglang_workers(sglang_config):
205189
workers = []
206190
procs = []
207191
env = _env()
208-
gpu_pool = list(range(_gpu_count()))
192+
gpu_assignments = build_gpu_assignments(
193+
worker_gpu_ids=None,
194+
num_workers=sglang_config["num_workers"],
195+
num_gpus_per_worker=sglang_config["num_gpus"],
196+
env=env,
197+
)
198+
if gpu_assignments is None:
199+
pytest.skip("No GPU assignments available")
209200

210201
for i in range(sglang_config["num_workers"]):
211-
port = _find_free_port()
212-
gpu_start = i * sglang_config["num_gpus"]
213-
gpu_end = gpu_start + sglang_config["num_gpus"]
214-
gpu_ids = ",".join(str(gpu_pool[g]) for g in range(gpu_start, gpu_end))
202+
port = _reserve_local_port(10090 + i * 2)
203+
gpu_ids = gpu_assignments[i]
215204

216205
worker_env = dict(env)
217206
worker_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
@@ -257,8 +246,7 @@ def sglang_workers(sglang_config):
257246
)
258247
except RuntimeError as exc:
259248
worker_errors = _collect_exited_worker_errors(procs)
260-
for p in procs:
261-
_kill_proc(p)
249+
terminate_all(procs)
262250
details = (
263251
"\n\n".join(worker_errors)
264252
if worker_errors
@@ -269,8 +257,7 @@ def sglang_workers(sglang_config):
269257
)
270258
except TimeoutError as exc:
271259
worker_errors = _collect_exited_worker_errors(procs)
272-
for p in procs:
273-
_kill_proc(p)
260+
terminate_all(procs)
274261
if worker_errors:
275262
details = "\n\n".join(worker_errors)
276263
pytest.fail(
@@ -281,8 +268,7 @@ def sglang_workers(sglang_config):
281268

282269
exited_after_ready = _collect_exited_worker_errors(procs)
283270
if exited_after_ready:
284-
for p in procs:
285-
_kill_proc(p)
271+
terminate_all(procs)
286272
pytest.fail(
287273
"sglang worker exited after becoming healthy:\n"
288274
+ "\n\n".join(exited_after_ready),
@@ -291,14 +277,13 @@ def sglang_workers(sglang_config):
291277

292278
yield workers
293279

294-
for p in procs:
295-
_kill_proc(p)
280+
terminate_all(procs)
296281

297282

298283
@pytest.fixture(scope="module")
299284
def router_url(sglang_workers):
300285
"""Launch a real router connected to real sglang workers."""
301-
port = _find_free_port()
286+
port = _reserve_local_port(12090)
302287
worker_urls = [w.url for w in sglang_workers]
303288
cmd = [
304289
PYTHON,
@@ -326,11 +311,11 @@ def router_url(sglang_workers):
326311
try:
327312
_wait_healthy(url, 30, label="router", proc=proc)
328313
except Exception:
329-
_kill_proc(proc)
314+
terminate_all([proc])
330315
raise
331316

332317
yield url
333-
_kill_proc(proc)
318+
terminate_all([proc])
334319

335320

336321
# -- Tests ------------------------------------------------------------------

0 commit comments

Comments
 (0)