Skip to content

Commit d7a5bf0

Browse files
committed
refactor(BA-5528): generalize chat client and address review comments
Address review comments from #11344:

- Drop chat_dto.py and switch the SDK to plain dict[str, Any] for both request and response, so it doesn't try to track every runtime variant's extension fields (vllm reasoning_content, tool_calls, etc.)
- Rename DeploymentChatClient -> InferenceChatClient and decouple it from the vllm runtime variant: works against any OpenAI-compatible endpoint (vllm, tgi, sglang, nim) and exposes a configurable path plus a list_models helper
- Rename the cached api key field vllm_api_key -> api_key throughout the cache schema, CLI options, show output, and tests
- chat-config set: --token is now optional and pairs with a new --no-token flag for deployments started without --api-key. The served model name is auto-discovered via GET /v1/models (option B from the discussion) so users no longer have to know it
- chat: replace the local _abort helper with click.ClickException, validate --max-tokens via click.IntRange(min=1) and the sampling knobs via click.FloatRange, and add --top-p, --frequency-penalty, --presence-penalty, --seed, --stop options
- inference_chat client: add ClientTimeout (sock_connect/sock_read) to the owned aiohttp session and normalize trailing slashes when building the chat / models URL
- cache loader: tolerate corrupted JSON (OSError/JSONDecodeError) and skip individual malformed entries instead of aborting the whole load
- tests: drop redundant atomic-write/permission-reset cases, add loader resilience cases, and shorten the changelog entry
1 parent 73b2c8a commit d7a5bf0

9 files changed

Lines changed: 520 additions & 489 deletions

File tree

changes/11344.feature.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Add `./bai deployment chat` and `./bai deployment chat-config` v2 CLI commands for one-shot OpenAI-compatible chat with deployed vLLM models, including a local cache (`~/.backend.ai/deployment_chat.json`, `0600`) of per-deployment endpoint URLs and API keys.
1+
Add `./bai deployment chat` for one-shot OpenAI-compatible chat against deployed inference services.

src/ai/backend/client/cli/v2/deployment/chat.py

Lines changed: 104 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
from __future__ import annotations
44

55
import asyncio
6-
import sys
76
from collections.abc import Awaitable, Callable
87
from datetime import UTC, datetime
9-
from typing import TYPE_CHECKING, Any, NoReturn
8+
from typing import Any
109
from uuid import UUID
1110

1211
import click
@@ -18,15 +17,7 @@
1817
load_chat_cache,
1918
save_chat_cache,
2019
)
21-
from ai.backend.client.cli.v2.helpers import create_v2_registry, load_v2_config, print_result
22-
23-
if TYPE_CHECKING:
24-
from ai.backend.client.v2.v2_registry import V2ClientRegistry
25-
26-
27-
def _abort(message: str) -> NoReturn:
28-
click.echo(message, err=True)
29-
sys.exit(1)
20+
from ai.backend.client.cli.v2.helpers import create_v2_registry, load_v2_config
3021

3122

3223
def _run_async(coro_fn: Callable[[], Awaitable[None]]) -> None:
@@ -40,19 +31,7 @@ def _run_async(coro_fn: Callable[[], Awaitable[None]]) -> None:
4031
msg = data.get("msg", "") if isinstance(data, dict) else ""
4132
status = e.args[0] if e.args else "?"
4233
detail = title or msg or str(e)
43-
click.echo(f"Error ({status}): {detail}", err=True)
44-
sys.exit(1)
45-
46-
47-
async def _resolve_endpoint_url(registry: V2ClientRegistry, deployment_id: UUID) -> str:
48-
deployment = await registry.deployment.get(deployment_id)
49-
endpoint_url = deployment.network_access.endpoint_url
50-
if not endpoint_url:
51-
raise click.ClickException(
52-
f"Deployment {deployment_id} has no endpoint_url yet "
53-
"(it may still be provisioning). Wait until the deployment is ACTIVE."
54-
)
55-
return endpoint_url
34+
raise click.ClickException(f"{status}: {detail}") from e
5635

5736

5837
@click.command(name="chat")
@@ -62,106 +41,168 @@ async def _resolve_endpoint_url(registry: V2ClientRegistry, deployment_id: UUID)
6241
"--model",
6342
default=None,
6443
type=str,
65-
help="Model name to send (defaults to cached default_model or 'default').",
44+
help="Model name to send (defaults to cached default_model).",
6645
)
6746
@click.option(
6847
"--temperature",
6948
default=None,
70-
type=float,
49+
type=click.FloatRange(min=0.0, max=2.0),
7150
help="Sampling temperature.",
7251
)
7352
@click.option(
74-
"--max-tokens",
53+
"--top-p",
54+
default=None,
55+
type=click.FloatRange(min=0.0, max=1.0),
56+
help="Nucleus sampling probability mass.",
57+
)
58+
@click.option(
59+
"--frequency-penalty",
60+
default=None,
61+
type=click.FloatRange(min=-2.0, max=2.0),
62+
help="Penalty for token frequency.",
63+
)
64+
@click.option(
65+
"--presence-penalty",
66+
default=None,
67+
type=click.FloatRange(min=-2.0, max=2.0),
68+
help="Penalty for token presence.",
69+
)
70+
@click.option(
71+
"--seed",
7572
default=None,
7673
type=int,
74+
help="Random seed for deterministic sampling.",
75+
)
76+
@click.option(
77+
"--stop",
78+
multiple=True,
79+
type=str,
80+
help="Stop sequence (repeatable).",
81+
)
82+
@click.option(
83+
"--max-tokens",
84+
default=None,
85+
type=click.IntRange(min=1),
7786
help="Maximum number of tokens to generate.",
7887
)
7988
def chat(
8089
deployment_id: UUID,
8190
content: str,
8291
model: str | None,
8392
temperature: float | None,
93+
top_p: float | None,
94+
frequency_penalty: float | None,
95+
presence_penalty: float | None,
96+
seed: int | None,
97+
stop: tuple[str, ...],
8498
max_tokens: int | None,
8599
) -> None:
86-
"""Send a one-shot chat completion request to a deployed vLLM model."""
87-
from ai.backend.client.exceptions import BackendAPIError
88-
from ai.backend.client.v2.chat_dto import ChatCompletionRequest, ChatMessage
89-
from ai.backend.client.v2.domains_v2.deployment_chat import (
90-
DeploymentChatAuthError,
91-
DeploymentChatClient,
100+
"""Send a one-shot chat completion request to a deployed model."""
101+
import json
102+
import sys
103+
104+
from ai.backend.client.v2.domains_v2.inference_chat import (
105+
InferenceChatAuthError,
106+
InferenceChatClient,
92107
)
93108

94109
config = load_v2_config()
95110

96111
try:
97112
cache = load_chat_cache()
98113
except IncompatibleChatCacheError as e:
99-
_abort(str(e))
114+
raise click.ClickException(str(e)) from e
100115

101116
entry = cache.get(deployment_id)
102117

103-
async def _ensure_endpoint() -> DeploymentChatCacheEntry:
118+
async def _resolve_endpoint() -> DeploymentChatCacheEntry:
104119
if entry is not None and entry.endpoint_url:
105120
return entry
106121
registry = await create_v2_registry(config)
107122
try:
108-
endpoint_url = await _resolve_endpoint_url(registry, deployment_id)
123+
deployment = await registry.deployment.get(deployment_id)
109124
finally:
110125
await registry.close()
126+
endpoint_url = deployment.network_access.endpoint_url
127+
if not endpoint_url:
128+
raise click.ClickException(
129+
f"Deployment {deployment_id} has no endpoint_url yet "
130+
"(it may still be provisioning). Wait until the deployment is READY."
131+
)
111132
new_entry = DeploymentChatCacheEntry(
112133
endpoint_url=endpoint_url,
113-
vllm_api_key=entry.vllm_api_key if entry is not None else None,
134+
api_key=entry.api_key if entry is not None else None,
114135
default_model=entry.default_model if entry is not None else None,
115136
last_synced_at=datetime.now(UTC),
116137
)
117138
cache.upsert(deployment_id, new_entry)
118139
save_chat_cache(cache)
119140
return new_entry
120141

142+
def _build_request_body(model_name: str) -> dict[str, Any]:
143+
body: dict[str, Any] = {
144+
"model": model_name,
145+
"messages": [{"role": "user", "content": content}],
146+
}
147+
if temperature is not None:
148+
body["temperature"] = temperature
149+
if top_p is not None:
150+
body["top_p"] = top_p
151+
if frequency_penalty is not None:
152+
body["frequency_penalty"] = frequency_penalty
153+
if presence_penalty is not None:
154+
body["presence_penalty"] = presence_penalty
155+
if seed is not None:
156+
body["seed"] = seed
157+
if stop:
158+
body["stop"] = list(stop)
159+
if max_tokens is not None:
160+
body["max_tokens"] = max_tokens
161+
return body
162+
121163
async def _run() -> None:
122-
resolved = await _ensure_endpoint()
123-
if resolved.vllm_api_key is None:
124-
_abort(
125-
f"No vLLM API key registered for deployment {deployment_id}.\n"
126-
"Register one with:\n"
127-
f" ./bai deployment chat-config set {deployment_id} --token <vllm_api_key>"
164+
from ai.backend.client.exceptions import BackendAPIError
165+
166+
resolved = await _resolve_endpoint()
167+
request_model = model or resolved.default_model
168+
if request_model is None:
169+
raise click.ClickException(
170+
f"No --model given and no default_model cached for deployment {deployment_id}.\n"
171+
"Set one with:\n"
172+
f" ./bai deployment chat-config set {deployment_id} --token <api_key>\n"
173+
"(this auto-discovers the served model from the inference endpoint)."
128174
)
129175

130-
request_model = model or resolved.default_model or "default"
131-
chat_request = ChatCompletionRequest(
132-
model=request_model,
133-
messages=[ChatMessage(role="user", content=content)],
134-
temperature=temperature,
135-
max_tokens=max_tokens,
136-
)
137-
138-
async with DeploymentChatClient(
176+
body = _build_request_body(request_model)
177+
async with InferenceChatClient(
139178
skip_ssl_verification=config.skip_ssl_verification,
140-
) as chat_client:
179+
) as client:
141180
try:
142-
response = await chat_client.chat_completion(
181+
response = await client.chat_completion(
143182
resolved.endpoint_url,
144-
resolved.vllm_api_key,
145-
chat_request,
183+
resolved.api_key,
184+
body,
146185
)
147-
except DeploymentChatAuthError:
186+
except InferenceChatAuthError as e:
148187
invalidated = DeploymentChatCacheEntry(
149188
endpoint_url=resolved.endpoint_url,
150-
vllm_api_key=None,
189+
api_key=None,
151190
default_model=resolved.default_model,
152191
last_synced_at=datetime.now(UTC),
153192
)
154193
cache.upsert(deployment_id, invalidated)
155194
save_chat_cache(cache)
156-
_abort(
195+
raise click.ClickException(
157196
f"The inference endpoint rejected the configured API key for "
158197
f"deployment {deployment_id}. The cached key has been cleared.\n"
159198
"Register a new one with:\n"
160-
f" ./bai deployment chat-config set {deployment_id} --token <vllm_api_key>"
161-
)
199+
f" ./bai deployment chat-config set {deployment_id} --token <api_key>"
200+
) from e
162201
except BackendAPIError as e:
163-
_abort(f"Inference endpoint error ({e.status} {e.reason}): {e.data}")
164-
print_result(response)
202+
raise click.ClickException(
203+
f"Inference endpoint error ({e.status} {e.reason}): {e.data}"
204+
) from e
205+
sys.stdout.write(json.dumps(response, indent=2, ensure_ascii=False, default=str) + "\n")
165206

166207
_run_async(_run)
167208

0 commit comments

Comments
 (0)