Skip to content

Commit caa6f25

Browse files
committed
test: add client test and simplify existing tests
1 parent bc0eaf5 commit caa6f25

File tree

5 files changed

+248
-64
lines changed

5 files changed

+248
-64
lines changed

tests/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
pytest tests/ # All tests
2525
2626
Configuration:
27-
pytest tests/integration/ --sglang-base-url=http://localhost:8000
27+
pytest tests/integration/ --sglang-base-url=http://localhost:30000
2828
pytest tests/integration/ --sglang-model-id=Qwen/Qwen3-4B-Instruct-2507
2929
3030
Or via environment variables:
31-
SGLANG_BASE_URL=http://localhost:8000 pytest tests/integration/
31+
SGLANG_BASE_URL=http://localhost:30000 pytest tests/integration/
3232
"""
3333

3434
import os
@@ -39,8 +39,8 @@ def pytest_addoption(parser):
3939
parser.addoption(
4040
"--sglang-base-url",
4141
action="store",
42-
default=os.environ.get("SGLANG_BASE_URL", "http://localhost:8000"),
43-
help="SGLang server URL (default: http://localhost:8000 or SGLANG_BASE_URL env var)",
42+
default=os.environ.get("SGLANG_BASE_URL", "http://localhost:30000"),
43+
help="SGLang server URL (default: http://localhost:30000 or SGLANG_BASE_URL env var)",
4444
)
4545
parser.addoption(
4646
"--sglang-model-id",

tests/integration/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
and require a running SGLang server.
1919
2020
Configuration (priority: CLI > env var > default):
21-
pytest --sglang-base-url=http://localhost:8000 --sglang-model-id=Qwen/Qwen3-4B-Instruct-2507
21+
pytest --sglang-base-url=http://localhost:30000 --sglang-model-id=Qwen/Qwen3-4B-Instruct-2507
2222
SGLANG_BASE_URL=http://... SGLANG_MODEL_ID=... pytest tests/integration/
2323
"""
2424

tests/integration/test_sglang_integration.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,21 +190,22 @@ async def test_incremental_tokenization(self, model):
190190

191191

192192
class TestSSEParsing:
193-
"""Tests for SSE event parsing."""
193+
"""Tests for SSE event parsing via SGLangClient."""
194194

195-
async def test_iter_sse_events(self, model):
196-
"""_iter_sse_events correctly parses SSE stream."""
195+
async def test_client_generate_parses_sse(self, model):
196+
"""SGLangClient.generate() correctly parses SSE stream."""
197197
messages = [{"role": "user", "content": [{"text": "Say 'test'"}]}]
198198

199-
# Manually call the internal stream to test SSE parsing
199+
# Tokenize and call client.generate() directly
200200
input_ids = model.tokenize_prompt_messages(messages, system_prompt=None)
201-
payload = model.build_sglang_payload(input_ids=input_ids)
201+
client = model._get_client()
202202

203-
async with model.client.stream("POST", "/generate", json=payload) as response:
204-
events = []
205-
async for event in model._iter_sse_events(response):
206-
events.append(event)
203+
events = []
204+
async for event in client.generate(input_ids=input_ids):
205+
events.append(event)
207206

208-
# Should have parsed JSON events
207+
# Should have parsed JSON events with expected fields
209208
assert len(events) > 0
210209
assert all(isinstance(e, dict) for e in events)
210+
# Final event should have output_ids
211+
assert "output_ids" in events[-1] or "text" in events[-1]

tests/unit/test_client.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# Copyright 2025 Horizon RL Contributors
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for SGLangClient (mocked, no server required)."""
16+
17+
from unittest.mock import MagicMock, patch
18+
19+
import httpx
20+
import pytest
21+
22+
from strands_sglang.client import RETRYABLE_STATUS_CODES, SGLangClient
23+
24+
25+
class TestSGLangClientInit:
26+
"""Tests for SGLangClient initialization."""
27+
28+
def test_default_config(self):
29+
"""Default configuration values."""
30+
client = SGLangClient("http://localhost:30000")
31+
32+
assert client.base_url == "http://localhost:30000"
33+
assert client.max_retries == 60
34+
assert client.retry_delay == 1.0
35+
36+
def test_base_url_strips_trailing_slash(self):
37+
"""Base URL trailing slash is stripped."""
38+
client = SGLangClient("http://localhost:30000/")
39+
assert client.base_url == "http://localhost:30000"
40+
41+
def test_custom_config(self):
42+
"""Custom configuration is applied."""
43+
client = SGLangClient(
44+
"http://custom:9000",
45+
max_connections=500,
46+
timeout=120.0,
47+
max_retries=10,
48+
retry_delay=2.0,
49+
)
50+
51+
assert client.base_url == "http://custom:9000"
52+
assert client.max_retries == 10
53+
assert client.retry_delay == 2.0
54+
55+
56+
class TestRetryableErrors:
57+
"""Tests for _is_retryable_error method."""
58+
59+
@pytest.fixture
60+
def client(self):
61+
return SGLangClient("http://localhost:30000")
62+
63+
def test_connect_error_is_retryable(self, client):
64+
"""ConnectError is retryable."""
65+
error = httpx.ConnectError("Connection refused")
66+
assert client._is_retryable_error(error) is True
67+
68+
def test_read_timeout_is_retryable(self, client):
69+
"""ReadTimeout is retryable."""
70+
error = httpx.ReadTimeout("Read timed out")
71+
assert client._is_retryable_error(error) is True
72+
73+
def test_pool_timeout_is_retryable(self, client):
74+
"""PoolTimeout is retryable."""
75+
error = httpx.PoolTimeout("Pool exhausted")
76+
assert client._is_retryable_error(error) is True
77+
78+
@pytest.mark.parametrize("status_code", RETRYABLE_STATUS_CODES)
79+
def test_5xx_errors_are_retryable(self, client, status_code):
80+
"""HTTP 5xx errors are retryable."""
81+
response = MagicMock()
82+
response.status_code = status_code
83+
error = httpx.HTTPStatusError("Server error", request=MagicMock(), response=response)
84+
assert client._is_retryable_error(error) is True
85+
86+
def test_400_is_not_retryable(self, client):
87+
"""HTTP 400 is not retryable."""
88+
response = MagicMock()
89+
response.status_code = 400
90+
error = httpx.HTTPStatusError("Bad request", request=MagicMock(), response=response)
91+
assert client._is_retryable_error(error) is False
92+
93+
def test_429_is_not_retryable(self, client):
94+
"""HTTP 429 is not retryable (rate limit is handled by caller)."""
95+
response = MagicMock()
96+
response.status_code = 429
97+
error = httpx.HTTPStatusError("Rate limited", request=MagicMock(), response=response)
98+
assert client._is_retryable_error(error) is False
99+
100+
def test_generic_exception_is_not_retryable(self, client):
101+
"""Generic exceptions are not retryable."""
102+
error = ValueError("Something wrong")
103+
assert client._is_retryable_error(error) is False
104+
105+
106+
class TestSSEParsing:
107+
"""Tests for _iter_sse_events method."""
108+
109+
@pytest.fixture
110+
def client(self):
111+
return SGLangClient("http://localhost:30000")
112+
113+
@pytest.mark.asyncio
114+
async def test_parse_valid_sse_events(self, client):
115+
"""Parse valid SSE events."""
116+
response = MagicMock()
117+
response.aiter_lines.return_value = AsyncIteratorMock(
118+
[
119+
'data: {"text": "Hello"}',
120+
'data: {"text": "Hello world"}',
121+
"data: [DONE]",
122+
]
123+
)
124+
125+
events = [e async for e in client._iter_sse_events(response)]
126+
127+
assert len(events) == 2
128+
assert events[0] == {"text": "Hello"}
129+
assert events[1] == {"text": "Hello world"}
130+
131+
@pytest.mark.asyncio
132+
async def test_skip_empty_lines(self, client):
133+
"""Empty lines are skipped."""
134+
response = MagicMock()
135+
response.aiter_lines.return_value = AsyncIteratorMock(
136+
[
137+
"",
138+
'data: {"text": "test"}',
139+
"",
140+
"data: [DONE]",
141+
]
142+
)
143+
144+
events = [e async for e in client._iter_sse_events(response)]
145+
146+
assert len(events) == 1
147+
assert events[0] == {"text": "test"}
148+
149+
@pytest.mark.asyncio
150+
async def test_skip_non_data_lines(self, client):
151+
"""Non-data lines are skipped."""
152+
response = MagicMock()
153+
response.aiter_lines.return_value = AsyncIteratorMock(
154+
[
155+
"event: message",
156+
'data: {"text": "test"}',
157+
": comment",
158+
"data: [DONE]",
159+
]
160+
)
161+
162+
events = [e async for e in client._iter_sse_events(response)]
163+
164+
assert len(events) == 1
165+
166+
@pytest.mark.asyncio
167+
async def test_skip_malformed_json(self, client):
168+
"""Malformed JSON is skipped."""
169+
response = MagicMock()
170+
response.aiter_lines.return_value = AsyncIteratorMock(
171+
[
172+
"data: not json",
173+
'data: {"text": "valid"}',
174+
"data: [DONE]",
175+
]
176+
)
177+
178+
events = [e async for e in client._iter_sse_events(response)]
179+
180+
assert len(events) == 1
181+
assert events[0] == {"text": "valid"}
182+
183+
184+
class TestHealth:
185+
"""Tests for health method."""
186+
187+
@pytest.mark.asyncio
188+
async def test_health_returns_true_on_200(self):
189+
"""Health returns True on 200 response."""
190+
with patch.object(httpx.AsyncClient, "get") as mock_get:
191+
mock_response = MagicMock()
192+
mock_response.status_code = 200
193+
mock_get.return_value = mock_response
194+
195+
client = SGLangClient("http://localhost:30000")
196+
result = await client.health()
197+
198+
assert result is True
199+
200+
@pytest.mark.asyncio
201+
async def test_health_returns_false_on_error(self):
202+
"""Health returns False on HTTP error."""
203+
with patch.object(httpx.AsyncClient, "get") as mock_get:
204+
mock_get.side_effect = httpx.ConnectError("Connection refused")
205+
206+
client = SGLangClient("http://localhost:30000")
207+
result = await client.health()
208+
209+
assert result is False
210+
211+
212+
class AsyncIteratorMock:
213+
"""Mock async iterator for testing."""
214+
215+
def __init__(self, items):
216+
self.items = items
217+
self.index = 0
218+
219+
def __aiter__(self):
220+
return self
221+
222+
async def __anext__(self):
223+
if self.index >= len(self.items):
224+
raise StopAsyncIteration
225+
item = self.items[self.index]
226+
self.index += 1
227+
return item

tests/unit/test_sglang.py

Lines changed: 5 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -119,46 +119,6 @@ def test_format_prompt_with_tools(self, model, mock_tokenizer):
119119
assert call_kwargs["tokenize"] is False
120120

121121

122-
class TestBuildSglangPayload:
123-
"""Tests for build_sglang_payload method."""
124-
125-
def test_minimal_payload(self, model):
126-
"""Build payload with minimal parameters."""
127-
payload = model.build_sglang_payload(input_ids=[1, 2, 3])
128-
129-
assert payload["input_ids"] == [1, 2, 3]
130-
assert payload["stream"] is True
131-
assert payload["return_logprob"] is True
132-
assert payload["logprob_start_len"] == 0
133-
134-
def test_payload_with_sampling_params(self, model):
135-
"""Build payload with sampling parameters."""
136-
sampling = {"temperature": 0.7, "max_tokens": 100}
137-
payload = model.build_sglang_payload(input_ids=[1, 2, 3], sampling_params=sampling)
138-
139-
assert payload["sampling_params"] == sampling
140-
141-
def test_payload_without_logprobs(self, model):
142-
"""Build payload without logprobs."""
143-
payload = model.build_sglang_payload(input_ids=[1, 2, 3], return_logprob=False)
144-
145-
assert "return_logprob" not in payload
146-
assert "logprob_start_len" not in payload
147-
148-
def test_payload_without_streaming(self, model):
149-
"""Build payload without streaming."""
150-
payload = model.build_sglang_payload(input_ids=[1, 2, 3], stream=False)
151-
152-
assert payload["stream"] is False
153-
154-
def test_payload_with_model_id(self, mock_tokenizer):
155-
"""Build payload with model ID from config."""
156-
model = SGLangModel(tokenizer=mock_tokenizer, model_id="qwen/qwen3-4b")
157-
payload = model.build_sglang_payload(input_ids=[1, 2, 3])
158-
159-
assert payload["model"] == "qwen/qwen3-4b"
160-
161-
162122
class TestTokenizePromptMessages:
163123
"""Tests for tokenize_prompt_messages method."""
164124

@@ -336,7 +296,7 @@ def test_default_config(self, mock_tokenizer):
336296
model = SGLangModel(tokenizer=mock_tokenizer)
337297
config = model.get_config()
338298

339-
assert config["base_url"] == "http://localhost:8000"
299+
assert config["base_url"] == "http://localhost:30000"
340300

341301
def test_custom_base_url(self, mock_tokenizer):
342302
"""Custom base URL is stored correctly."""
@@ -354,15 +314,11 @@ def test_update_config(self, model):
354314
assert config["model_id"] == "new-model"
355315

356316
def test_config_with_timeout_float(self, mock_tokenizer):
357-
"""Configuration with timeout float (connect is always 5.0 like OpenAI)."""
317+
"""Configuration with custom timeout."""
358318
model = SGLangModel(tokenizer=mock_tokenizer, timeout=300.0)
359-
timeout = model._client_config["timeout"]
360-
assert timeout.connect == 5.0 # Fixed like OpenAI
361-
assert timeout.read == 300.0
319+
assert model._timeout == 300.0
362320

363321
def test_config_with_default_timeout(self, mock_tokenizer):
364-
"""Configuration with default timeout (600s like OpenAI)."""
322+
"""Configuration with default timeout (None = infinite, like SLIME)."""
365323
model = SGLangModel(tokenizer=mock_tokenizer)
366-
timeout = model._client_config["timeout"]
367-
assert timeout.connect == 5.0
368-
assert timeout.read == 600.0 # Default 10min like OpenAI
324+
assert model._timeout is None # Infinite timeout by default

0 commit comments

Comments
 (0)