Skip to content

Commit e85a853

Browse files
committed
refactor(BA-5528): align DeploymentChatClient session lifecycle, add Args type
Address review comments on PR #11344: - Drop _owns_session and the optional session= kwarg on DeploymentChatClient. Match BackendAIAuthClient: __init__ takes a pre-built session, factory method create() builds one, close() always closes. Removes the dual-ownership branch. - Introduce DeploymentChatClientArgs (frozen dataclass) for connection knobs (skip_ssl_verification, connect_timeout, read_timeout). Callers use DeploymentChatClient.create(args) instead of passing multiple kwargs to the constructor. - Rename chat_completion's 'request' parameter to 'body'. - Tests: rename the cache-entry helper to _make_entry, the chat-body helper to _make_body. Drop TestExternalSession since the new contract is 'whatever you pass to __init__ gets closed'.
1 parent b42bc94 commit e85a853

5 files changed

Lines changed: 55 additions & 57 deletions

File tree

src/ai/backend/client/cli/v2/deployment/chat/commands.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def chat(
9393
from ai.backend.client.v2.deployment_chat import (
9494
DeploymentChatAuthError,
9595
DeploymentChatClient,
96+
DeploymentChatClientArgs,
9697
)
9798

9899
connection = load_v2_config()
@@ -150,9 +151,10 @@ async def _run() -> None:
150151
"messages": [{"role": "user", "content": content}],
151152
}
152153
api_key = chat_config_store.get_token(deployment_id)
153-
async with DeploymentChatClient(
154+
client_args = DeploymentChatClientArgs(
154155
skip_ssl_verification=connection.skip_ssl_verification,
155-
) as client:
156+
)
157+
async with await DeploymentChatClient.create(client_args) as client:
156158
try:
157159
response = await client.chat_completion(
158160
endpoint_entry.endpoint_url,
@@ -281,9 +283,11 @@ async def _discover_model(
281283
from ai.backend.client.v2.deployment_chat import (
282284
DeploymentChatAuthError,
283285
DeploymentChatClient,
286+
DeploymentChatClientArgs,
284287
)
285288

286-
async with DeploymentChatClient(skip_ssl_verification=skip_ssl_verification) as client:
289+
client_args = DeploymentChatClientArgs(skip_ssl_verification=skip_ssl_verification)
290+
async with await DeploymentChatClient.create(client_args) as client:
287291
try:
288292
payload = await client.list_models(endpoint_url, api_key)
289293
except (DeploymentChatAuthError, BackendAPIError, BackendClientError):

src/ai/backend/client/v2/deployment_chat.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from __future__ import annotations
1414

15+
from dataclasses import dataclass
1516
from types import TracebackType
1617
from typing import Any, Self
1718

@@ -28,31 +29,32 @@ class DeploymentChatAuthError(BackendAPIError):
2829
"""Raised when the inference endpoint rejects the configured API key."""
2930

3031

32+
@dataclass(frozen=True)
33+
class DeploymentChatClientArgs:
34+
"""Connection knobs for :meth:`DeploymentChatClient.create`."""
35+
36+
skip_ssl_verification: bool = False
37+
connect_timeout: float | None = 10.0
38+
read_timeout: float | None = 300.0
39+
40+
3141
class DeploymentChatClient:
3242
"""Direct HTTP client for OpenAI-compatible inference endpoints."""
3343

3444
_session: aiohttp.ClientSession
35-
_owns_session: bool
3645

37-
def __init__(
38-
self,
39-
*,
40-
session: aiohttp.ClientSession | None = None,
41-
skip_ssl_verification: bool = False,
42-
connect_timeout: float | None = 10.0,
43-
read_timeout: float | None = 300.0,
44-
) -> None:
45-
if session is not None:
46-
self._session = session
47-
self._owns_session = False
48-
else:
49-
connector = aiohttp.TCPConnector(ssl=not skip_ssl_verification)
50-
timeout = aiohttp.ClientTimeout(
51-
sock_connect=connect_timeout,
52-
sock_read=read_timeout,
53-
)
54-
self._session = aiohttp.ClientSession(connector=connector, timeout=timeout)
55-
self._owns_session = True
46+
def __init__(self, session: aiohttp.ClientSession) -> None:
47+
self._session = session
48+
49+
@classmethod
50+
async def create(cls, args: DeploymentChatClientArgs) -> Self:
51+
connector = aiohttp.TCPConnector(ssl=not args.skip_ssl_verification)
52+
timeout = aiohttp.ClientTimeout(
53+
sock_connect=args.connect_timeout,
54+
sock_read=args.read_timeout,
55+
)
56+
session = aiohttp.ClientSession(connector=connector, timeout=timeout)
57+
return cls(session)
5658

5759
async def __aenter__(self) -> Self:
5860
return self
@@ -66,20 +68,20 @@ async def __aexit__(
6668
await self.close()
6769

6870
async def close(self) -> None:
69-
if self._owns_session and not self._session.closed:
71+
if not self._session.closed:
7072
await self._session.close()
7173

7274
async def chat_completion(
7375
self,
7476
endpoint_url: str,
7577
api_key: str | None,
76-
request: dict[str, Any],
78+
body: dict[str, Any],
7779
*,
7880
path: str = DEFAULT_CHAT_PATH,
7981
) -> dict[str, Any]:
8082
"""POST a chat completion request to the deployment endpoint."""
8183
target = self._build_url(endpoint_url, path)
82-
return await self._post(target, api_key, request)
84+
return await self._post(target, api_key, body)
8385

8486
async def list_models(
8587
self,

tests/unit/client/cli/test_deployment_chat_types.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111

1212

13-
def _entry(
13+
def _make_entry(
1414
*,
1515
endpoint: str = "https://infer.example.test/api",
1616
default_model: str | None = None,
@@ -24,32 +24,32 @@ def _entry(
2424

2525
class TestEntryFormatSummary:
2626
def test_format_summary_returns_lines(self) -> None:
27-
entry = _entry(default_model="meta/test-model")
27+
entry = _make_entry(default_model="meta/test-model")
2828
lines = entry.format_summary()
2929
assert any("endpoint_url" in line for line in lines)
3030
assert any("meta/test-model" in line for line in lines)
3131
assert any("last_synced_at" in line for line in lines)
3232

3333
def test_format_summary_dash_for_missing_default_model(self) -> None:
34-
entry = _entry(default_model=None)
34+
entry = _make_entry(default_model=None)
3535
lines = entry.format_summary()
3636
assert any("default_model : -" in line for line in lines)
3737

3838

3939
class TestCacheMutations:
40-
def test_upsert_overwrites_existing_entry(self) -> None:
40+
def test_upsert_overwrites_existing_make_entry(self) -> None:
4141
cache = DeploymentChatCache()
4242
dep_id = uuid4()
43-
cache.upsert(dep_id, _entry(default_model="m1"))
44-
cache.upsert(dep_id, _entry(default_model="m2"))
43+
cache.upsert(dep_id, _make_entry(default_model="m1"))
44+
cache.upsert(dep_id, _make_entry(default_model="m2"))
4545
stored = cache.get(dep_id)
4646
assert stored is not None
4747
assert stored.default_model == "m2"
4848

4949
def test_remove_returns_true_when_present(self) -> None:
5050
cache = DeploymentChatCache()
5151
dep_id = uuid4()
52-
cache.upsert(dep_id, _entry())
52+
cache.upsert(dep_id, _make_entry())
5353
assert cache.remove(dep_id) is True
5454
assert cache.get(dep_id) is None
5555

tests/unit/client/cli/test_deployment_chat_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828

2929

30-
def _entry(
30+
def _make_entry(
3131
*,
3232
endpoint: str = "https://infer.example.test/api",
3333
default_model: str | None = None,
@@ -44,11 +44,11 @@ def test_load_returns_empty_when_file_missing(self, tmp_path: Path) -> None:
4444
cache = load_chat_cache(tmp_path / "missing.json")
4545
assert cache.deployments == {}
4646

47-
def test_save_then_load_preserves_entry(self, tmp_path: Path) -> None:
47+
def test_save_then_load_preserves_make_entry(self, tmp_path: Path) -> None:
4848
path = tmp_path / "deployment_chat.json"
4949
cache = DeploymentChatCache()
5050
dep_id = uuid4()
51-
original = _entry(default_model="gpt-test")
51+
original = _make_entry(default_model="gpt-test")
5252
cache.upsert(dep_id, original)
5353
save_chat_cache(cache, path)
5454

@@ -96,7 +96,7 @@ class TestPermissions:
9696
def test_save_chat_cache_enforces_0600(self, tmp_path: Path) -> None:
9797
path = tmp_path / "cache.json"
9898
cache = DeploymentChatCache()
99-
cache.upsert(uuid4(), _entry())
99+
cache.upsert(uuid4(), _make_entry())
100100
save_chat_cache(cache, path)
101101
assert stat.S_IMODE(path.stat().st_mode) == 0o600
102102

tests/unit/client/v2/test_deployment_chat_client.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from collections.abc import AsyncIterator, Awaitable, Callable
44
from typing import Any
55

6-
import aiohttp
76
import pytest
87
from aiohttp import web
98

109
from ai.backend.client.exceptions import BackendAPIError, BackendClientError
1110
from ai.backend.client.v2.deployment_chat import (
1211
DeploymentChatAuthError,
1312
DeploymentChatClient,
13+
DeploymentChatClientArgs,
1414
)
1515

1616
HandlerFn = Callable[[web.Request], Awaitable[web.StreamResponse]]
@@ -61,14 +61,14 @@ async def wrapped(request: web.Request) -> web.StreamResponse:
6161

6262
@pytest.fixture
6363
async def chat_client() -> AsyncIterator[DeploymentChatClient]:
64-
client = DeploymentChatClient()
64+
client = await DeploymentChatClient.create(DeploymentChatClientArgs())
6565
try:
6666
yield client
6767
finally:
6868
await client.close()
6969

7070

71-
def _request_body() -> dict[str, Any]:
71+
def _make_body() -> dict[str, Any]:
7272
return {
7373
"model": "meta/test-model",
7474
"messages": [{"role": "user", "content": "hello"}],
@@ -96,7 +96,7 @@ async def handler(_request: web.Request) -> web.Response:
9696
resp = await chat_client.chat_completion(
9797
server.base_url,
9898
"sk-test-token",
99-
_request_body(),
99+
_make_body(),
100100
)
101101
finally:
102102
await server.stop()
@@ -105,7 +105,7 @@ async def handler(_request: web.Request) -> web.Response:
105105
assert server.recorded["path"] == "/v1/chat/completions"
106106
assert server.recorded["headers"]["Authorization"] == "Bearer sk-test-token"
107107
assert server.recorded["headers"]["Content-Type"] == "application/json"
108-
assert server.recorded["json"] == _request_body()
108+
assert server.recorded["json"] == _make_body()
109109
assert resp["choices"][0]["message"]["content"] == "hi"
110110

111111
async def test_endpoint_url_already_ending_in_chat_completions(
@@ -117,7 +117,7 @@ async def handler(_request: web.Request) -> web.Response:
117117
server = await _start_server("POST", "/v1/chat/completions", handler)
118118
try:
119119
full_url = f"{server.base_url}/v1/chat/completions"
120-
await chat_client.chat_completion(full_url, "sk-x", _request_body())
120+
await chat_client.chat_completion(full_url, "sk-x", _make_body())
121121
finally:
122122
await server.stop()
123123
assert server.recorded["path"] == "/v1/chat/completions"
@@ -131,7 +131,7 @@ async def handler(_request: web.Request) -> web.Response:
131131
server = await _start_server("POST", "/v1/chat/completions", handler)
132132
try:
133133
full_url = f"{server.base_url}/v1/chat/completions/"
134-
await chat_client.chat_completion(full_url, "sk-x", _request_body())
134+
await chat_client.chat_completion(full_url, "sk-x", _make_body())
135135
finally:
136136
await server.stop()
137137
assert server.recorded["path"] == "/v1/chat/completions"
@@ -144,7 +144,7 @@ async def handler(_request: web.Request) -> web.Response:
144144

145145
server = await _start_server("POST", "/v1/chat/completions", handler)
146146
try:
147-
await chat_client.chat_completion(server.base_url, None, _request_body())
147+
await chat_client.chat_completion(server.base_url, None, _make_body())
148148
finally:
149149
await server.stop()
150150
assert "Authorization" not in server.recorded["headers"]
@@ -177,7 +177,7 @@ async def handler(_request: web.Request) -> web.Response:
177177
server = await _start_server("POST", "/v1/chat/completions", handler)
178178
try:
179179
with pytest.raises(DeploymentChatAuthError) as exc_info:
180-
await chat_client.chat_completion(server.base_url, "bad", _request_body())
180+
await chat_client.chat_completion(server.base_url, "bad", _make_body())
181181
finally:
182182
await server.stop()
183183
assert exc_info.value.status == 401
@@ -191,7 +191,7 @@ async def handler(_request: web.Request) -> web.Response:
191191
server = await _start_server("POST", "/v1/chat/completions", handler)
192192
try:
193193
with pytest.raises(DeploymentChatAuthError):
194-
await chat_client.chat_completion(server.base_url, "bad", _request_body())
194+
await chat_client.chat_completion(server.base_url, "bad", _make_body())
195195
finally:
196196
await server.stop()
197197

@@ -206,7 +206,7 @@ async def handler(_request: web.Request) -> web.Response:
206206
server = await _start_server("POST", "/v1/chat/completions", handler)
207207
try:
208208
with pytest.raises(BackendAPIError) as exc_info:
209-
await chat_client.chat_completion(server.base_url, "sk", _request_body())
209+
await chat_client.chat_completion(server.base_url, "sk", _make_body())
210210
finally:
211211
await server.stop()
212212
assert not isinstance(exc_info.value, DeploymentChatAuthError)
@@ -223,14 +223,6 @@ async def handler(_request: web.Request) -> web.Response:
223223
server = await _start_server("POST", "/v1/chat/completions", handler)
224224
try:
225225
with pytest.raises(BackendClientError):
226-
await chat_client.chat_completion(server.base_url, "sk", _request_body())
226+
await chat_client.chat_completion(server.base_url, "sk", _make_body())
227227
finally:
228228
await server.stop()
229-
230-
231-
class TestExternalSession:
232-
async def test_does_not_close_externally_owned_session(self) -> None:
233-
async with aiohttp.ClientSession() as external:
234-
client = DeploymentChatClient(session=external)
235-
await client.close()
236-
assert external.closed is False

0 commit comments

Comments
 (0)