Skip to content

Commit 39e915e

Browse files
committed
fix(util): fall back sglang sync utils
1 parent ace4bef commit 39e915e

File tree

3 files changed

+46
-19
lines changed

3 files changed

+46
-19
lines changed

src/strands_env/cli/utils.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -244,22 +244,19 @@ def build_model_factory(config: ModelConfig, max_concurrency: int) -> ModelFacto
244244

245245
def _build_sglang_model_factory(config: ModelConfig, max_concurrency: int, sampling: dict) -> ModelFactory:
246246
"""Build SGLang model factory."""
247-
import asyncio
248-
249-
from strands_env.utils.sglang import get_cached_client, get_cached_tokenizer
250-
251-
client = get_cached_client(config.base_url, max_concurrency)
247+
from strands_env.utils.sglang import check_server_health, get_cached_client, get_cached_tokenizer, get_model_id
252248

253249
# Check server health before proceeding
254250
try:
255-
if not asyncio.run(client.health()):
256-
raise ConnectionError(f"SGLang server at {config.base_url} is not healthy")
257-
except Exception as e:
258-
raise click.ClickException(f"SGLang server at {config.base_url} is not reachable: {e}")
251+
check_server_health(config.base_url)
252+
except ConnectionError as e:
253+
raise click.ClickException(str(e))
254+
255+
client = get_cached_client(config.base_url, max_concurrency)
259256

260257
# Resolve and backfill model_id/tokenizer_path for reproducibility
261258
if not config.model_id:
262-
config.model_id = asyncio.run(client.get_model_info())["model_path"]
259+
config.model_id = get_model_id(config.base_url)
263260
if not config.tokenizer_path:
264261
config.tokenizer_path = config.model_id
265262

src/strands_env/utils/sglang.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
from functools import lru_cache
2020
from typing import TYPE_CHECKING, Any
2121

22+
import httpx
2223
from strands_sglang import SGLangClient
2324

2425
if TYPE_CHECKING:
@@ -68,3 +69,35 @@ def clear_clients() -> None:
6869
def clear_tokenizers() -> None:
    """Empty the cache behind ``get_cached_tokenizer``.

    Delegates to the ``cache_clear`` hook exposed on ``get_cached_tokenizer``.
    """
    reset_cache = get_cached_tokenizer.cache_clear
    reset_cache()
72+
73+
74+
def check_server_health(base_url: str, timeout: float = 5.0) -> None:
    """Check if the SGLang server is reachable and healthy.

    Issues a synchronous ``GET`` against ``{base_url}/health`` and converts
    any failure into a ``ConnectionError`` so callers only need to handle
    one exception type.

    Args:
        base_url: Base URL of the SGLang server. Trailing slashes are tolerated.
        timeout: Request timeout in seconds.

    Raises:
        ConnectionError: If the request fails, the server answers with an
            error status, or ``base_url`` cannot be parsed as a URL.
    """
    # Normalise the base so "http://host:port/" does not yield "//health",
    # which a strict router would not match against "/health".
    url = f"{base_url.rstrip('/')}/health"
    try:
        response = httpx.get(url, timeout=timeout)
        response.raise_for_status()
    except (httpx.HTTPError, httpx.InvalidURL) as e:
        # httpx.InvalidURL does not derive from httpx.HTTPError, so it has to
        # be listed explicitly to honour the documented ConnectionError contract.
        raise ConnectionError(f"SGLang server at {base_url} is not reachable: {e}") from e
89+
90+
91+
def get_model_id(base_url: str, timeout: float = 5.0) -> str:
    """Get the model ID from the SGLang server.

    Issues a synchronous ``GET`` against ``{base_url}/get_model_info`` and
    returns the ``model_path`` entry of the JSON response.

    Args:
        base_url: Base URL of the SGLang server. Trailing slashes are tolerated.
        timeout: Request timeout in seconds.

    Returns:
        The model path/ID from the server.

    Raises:
        httpx.HTTPError: If the request fails or the server answers with an
            error status (propagated unchanged from ``httpx``).
        KeyError: If the response payload has no ``model_path`` entry
            (assumes the payload is a JSON object).
    """
    # Normalise the base so "http://host:port/" does not yield
    # "//get_model_info", which a strict router would not match.
    url = f"{base_url.rstrip('/')}/get_model_info"
    response = httpx.get(url, timeout=timeout)
    response.raise_for_status()
    return response.json()["model_path"]

tests/integration/conftest.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -9,13 +9,12 @@
99
SGLANG_BASE_URL=http://... pytest tests/integration/
1010
"""
1111

12-
import asyncio
13-
1412
import pytest
1513
from strands_sglang import SGLangClient
1614
from transformers import AutoTokenizer
1715

1816
from strands_env.core.models import DEFAULT_SAMPLING_PARAMS, sglang_model_factory
17+
from strands_env.utils.sglang import check_server_health, get_model_id
1918

2019
# Mark all tests in this directory as integration tests
2120
pytestmark = pytest.mark.integration
@@ -30,19 +29,17 @@ def sglang_base_url(request):
3029
@pytest.fixture(scope="session")
3130
def sglang_client(sglang_base_url):
3231
"""Shared SGLang client for connection pooling. Skips all tests if server is unreachable."""
33-
client = SGLangClient(sglang_base_url)
3432
try:
35-
if not asyncio.run(client.health()):
36-
pytest.skip(f"SGLang server at {sglang_base_url} is not healthy")
37-
except Exception:
33+
check_server_health(sglang_base_url)
34+
except ConnectionError:
3835
pytest.skip(f"SGLang server not reachable at {sglang_base_url}")
39-
return client
36+
return SGLangClient(sglang_base_url)
4037

4138

4239
@pytest.fixture(scope="session")
43-
def sglang_model_id(sglang_client):
40+
def sglang_model_id(sglang_base_url):
4441
"""Auto-detect model ID from the running SGLang server."""
45-
return asyncio.run(sglang_client.get_model_info())["model_path"]
42+
return get_model_id(sglang_base_url)
4643

4744

4845
@pytest.fixture(scope="session")

0 commit comments

Comments
 (0)