2424import base64
2525import os
2626import shutil
27- import signal
28- import socket
2927import subprocess
3028import sys
3129from pathlib import Path
3230
3331import httpx
3432import pytest
3533
36- from sglang_diffusion_routing .launcher .utils import wait_for_health
34+ from sglang_diffusion_routing .launcher .utils import (
35+ build_gpu_assignments ,
36+ reserve_available_port ,
37+ resolve_gpu_pool ,
38+ terminate_all ,
39+ wait_for_health ,
40+ )
3741
3842REPO_ROOT = Path (__file__ ).resolve ().parents [2 ]
3943PYTHON = sys .executable
@@ -59,22 +63,19 @@ def _has_sglang() -> bool:
5963
6064
6165def _gpu_count () -> int :
62- try :
63- import torch
64-
65- return torch .cuda .device_count ()
66- except Exception :
67- return 0
66+ gpu_pool = resolve_gpu_pool (worker_gpu_ids = None , env = os .environ )
67+ return len (gpu_pool ) if gpu_pool else 0
6868
6969
7070def _get_env (key : str , default : str ) -> str :
7171 return os .environ .get (key , default )
7272
7373
74- def _find_free_port () -> int :
75- with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
76- s .bind (("127.0.0.1" , 0 ))
77- return s .getsockname ()[1 ]
74+ _USED_PORTS : set [int ] = set ()
75+
76+
77+ def _reserve_local_port (preferred_port : int ) -> int :
78+ return reserve_available_port ("127.0.0.1" , preferred_port , _USED_PORTS )
7879
7980
8081def _env () -> dict [str , str ]:
@@ -97,23 +98,6 @@ def _wait_healthy(
9798 )
9899
99100
100- def _kill_proc (proc : subprocess .Popen ) -> None :
101- if proc .poll () is not None :
102- return
103- try :
104- os .killpg (proc .pid , signal .SIGTERM )
105- except (ProcessLookupError , PermissionError ):
106- pass
107- try :
108- proc .wait (timeout = 15 )
109- except subprocess .TimeoutExpired :
110- try :
111- os .killpg (proc .pid , signal .SIGKILL )
112- except (ProcessLookupError , PermissionError ):
113- pass
114- proc .wait (timeout = 5 )
115-
116-
117101def _read_stderr_snippet (proc : subprocess .Popen , max_bytes : int = 8192 ) -> str :
118102 if proc .stderr is None :
119103 return ""
@@ -205,13 +189,18 @@ def sglang_workers(sglang_config):
205189 workers = []
206190 procs = []
207191 env = _env ()
208- gpu_pool = list (range (_gpu_count ()))
192+ gpu_assignments = build_gpu_assignments (
193+ worker_gpu_ids = None ,
194+ num_workers = sglang_config ["num_workers" ],
195+ num_gpus_per_worker = sglang_config ["num_gpus" ],
196+ env = env ,
197+ )
198+ if gpu_assignments is None :
199+ pytest .skip ("No GPU assignments available" )
209200
210201 for i in range (sglang_config ["num_workers" ]):
211- port = _find_free_port ()
212- gpu_start = i * sglang_config ["num_gpus" ]
213- gpu_end = gpu_start + sglang_config ["num_gpus" ]
214- gpu_ids = "," .join (str (gpu_pool [g ]) for g in range (gpu_start , gpu_end ))
202+ port = _reserve_local_port (10090 + i * 2 )
203+ gpu_ids = gpu_assignments [i ]
215204
216205 worker_env = dict (env )
217206 worker_env ["CUDA_VISIBLE_DEVICES" ] = gpu_ids
@@ -257,8 +246,7 @@ def sglang_workers(sglang_config):
257246 )
258247 except RuntimeError as exc :
259248 worker_errors = _collect_exited_worker_errors (procs )
260- for p in procs :
261- _kill_proc (p )
249+ terminate_all (procs )
262250 details = (
263251 "\n \n " .join (worker_errors )
264252 if worker_errors
@@ -269,8 +257,7 @@ def sglang_workers(sglang_config):
269257 )
270258 except TimeoutError as exc :
271259 worker_errors = _collect_exited_worker_errors (procs )
272- for p in procs :
273- _kill_proc (p )
260+ terminate_all (procs )
274261 if worker_errors :
275262 details = "\n \n " .join (worker_errors )
276263 pytest .fail (
@@ -281,8 +268,7 @@ def sglang_workers(sglang_config):
281268
282269 exited_after_ready = _collect_exited_worker_errors (procs )
283270 if exited_after_ready :
284- for p in procs :
285- _kill_proc (p )
271+ terminate_all (procs )
286272 pytest .fail (
287273 "sglang worker exited after becoming healthy:\n "
288274 + "\n \n " .join (exited_after_ready ),
@@ -291,14 +277,13 @@ def sglang_workers(sglang_config):
291277
292278 yield workers
293279
294- for p in procs :
295- _kill_proc (p )
280+ terminate_all (procs )
296281
297282
298283@pytest .fixture (scope = "module" )
299284def router_url (sglang_workers ):
300285 """Launch a real router connected to real sglang workers."""
301- port = _find_free_port ( )
286+ port = _reserve_local_port ( 12090 )
302287 worker_urls = [w .url for w in sglang_workers ]
303288 cmd = [
304289 PYTHON ,
@@ -326,11 +311,11 @@ def router_url(sglang_workers):
326311 try :
327312 _wait_healthy (url , 30 , label = "router" , proc = proc )
328313 except Exception :
329- _kill_proc ( proc )
314+ terminate_all ([ proc ] )
330315 raise
331316
332317 yield url
333- _kill_proc ( proc )
318+ terminate_all ([ proc ] )
334319
335320
336321# -- Tests ------------------------------------------------------------------
0 commit comments