Skip to content

Commit c5e2505

Browse files
committed
refactor(BA-5528): apply review feedback for chat CLI
Address review comments on PR #11344: - chat.py: - Drop the auto-clear of the cached API key on inference 401/403 — it was deleting user-supplied config out from under them. Just raise the error and ask the user to re-register. - Use print() instead of sys.stdout.write() for the response payload. - chat_config.py: - Remove --no-token; clearing is the dedicated chat-config clear command's job. Resolved-key handling collapses to a single expression. - Use print() instead of click.echo() for status output. - Parse the inference endpoint's /v1/models response with a typed Pydantic model (_ServedModelsResponse) instead of manual dict.get walking. - _print_entry now delegates the entry portion to DeploymentChatCacheEntry.format_summary() so the per-entry fields are owned by the cache type. - deployment_chat_cache.py / deployment_chat_config.py: - Drop schema_version as a Pydantic field on the wrapper model. The version is metadata, not data — emit it manually around model_dump in save_*, and check it manually in load_* before validating individual records. - DeploymentChatCacheEntry gains a format_summary() method returning the endpoint/default_model/last_synced_at lines so consumers don't duplicate that formatting.
1 parent 93d4578 commit c5e2505

5 files changed

Lines changed: 61 additions & 50 deletions

File tree

src/ai/backend/client/cli/v2/deployment/chat.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from ai.backend.client.cli.v2.deployment_chat_config import (
2222
IncompatibleChatConfigError,
2323
load_chat_config,
24-
save_chat_config,
2524
)
2625
from ai.backend.client.cli.v2.helpers import create_v2_registry, load_v2_config
2726

@@ -73,7 +72,6 @@ def chat(
7372
runtime variants. Pass them through ``--params`` instead.
7473
"""
7574
import json
76-
import sys
7775

7876
from ai.backend.client.v2.deployment_chat import (
7977
DeploymentChatAuthError,
@@ -148,19 +146,16 @@ async def _run() -> None:
148146
body,
149147
)
150148
except DeploymentChatAuthError as e:
151-
chat_config.clear_token(deployment_id)
152-
save_chat_config(chat_config)
153149
raise click.ClickException(
154150
f"The inference endpoint rejected the configured API key for "
155-
f"deployment {deployment_id}. The cached key has been cleared.\n"
156-
"Register a new one with:\n"
151+
f"deployment {deployment_id}. Re-register with:\n"
157152
f" ./bai deployment chat-config set {deployment_id} --token <api_key>"
158153
) from e
159154
except BackendAPIError as e:
160155
raise click.ClickException(
161156
f"Inference endpoint error ({e.status} {e.reason}): {e.data}"
162157
) from e
163-
sys.stdout.write(json.dumps(response, indent=2, ensure_ascii=False, default=str) + "\n")
158+
print(json.dumps(response, indent=2, ensure_ascii=False, default=str))
164159

165160
_run_async(_run)
166161

src/ai/backend/client/cli/v2/deployment/chat_config.py

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from uuid import UUID
1010

1111
import click
12+
from pydantic import BaseModel, ConfigDict, Field, ValidationError
1213

1314
from ai.backend.client.cli.v2.deployment_chat_cache import (
1415
DeploymentChatCacheEntry,
@@ -25,6 +26,16 @@
2526
from ai.backend.client.cli.v2.helpers import create_v2_registry, load_v2_config
2627

2728

29+
class _ServedModelEntry(BaseModel):
30+
model_config = ConfigDict(extra="allow")
31+
id: str
32+
33+
34+
class _ServedModelsResponse(BaseModel):
35+
model_config = ConfigDict(extra="allow")
36+
data: list[_ServedModelEntry] = Field(default_factory=list)
37+
38+
2839
def _run_async(coro_fn: Callable[[], Awaitable[None]]) -> None:
2940
from ai.backend.client.exceptions import BackendAPIError
3041

@@ -61,12 +72,6 @@ def chat_config() -> None:
6172
"Omit when the runtime was started without an API key."
6273
),
6374
)
64-
@click.option(
65-
"--no-token",
66-
is_flag=True,
67-
default=False,
68-
help="Explicitly clear the cached API key (deployment exposes no auth).",
69-
)
7075
@click.option(
7176
"--default-model",
7277
default=None,
@@ -79,13 +84,9 @@ def chat_config() -> None:
7984
def set_(
8085
deployment_id: UUID,
8186
api_key: str | None,
82-
no_token: bool,
8387
default_model: str | None,
8488
) -> None:
8589
"""Register or update the chat cache entry for a deployment."""
86-
if api_key and no_token:
87-
raise click.ClickException("--token and --no-token are mutually exclusive.")
88-
8990
connection = load_v2_config()
9091
try:
9192
cache = load_chat_cache()
@@ -97,13 +98,7 @@ def set_(
9798
raise click.ClickException(str(e)) from e
9899

99100
existing_entry = cache.get(deployment_id)
100-
resolved_key: str | None
101-
if no_token:
102-
resolved_key = None
103-
elif api_key is not None:
104-
resolved_key = api_key
105-
else:
106-
resolved_key = chat_config_store.get_token(deployment_id)
101+
resolved_key = api_key if api_key is not None else chat_config_store.get_token(deployment_id)
107102

108103
async def _run() -> None:
109104
registry = await create_v2_registry(connection)
@@ -137,16 +132,14 @@ async def _run() -> None:
137132
),
138133
)
139134
save_chat_cache(cache)
140-
if resolved_key is None:
141-
chat_config_store.clear_token(deployment_id)
142-
else:
135+
if resolved_key is not None:
143136
chat_config_store.set_token(deployment_id, resolved_key)
144-
save_chat_config(chat_config_store)
137+
save_chat_config(chat_config_store)
145138

146-
click.echo(f"Updated chat cache entry for deployment {deployment_id}.")
139+
print(f"Updated chat cache entry for deployment {deployment_id}.")
147140
if served_model:
148-
click.echo(f" default_model: {served_model}")
149-
click.echo(f" api_key: {mask_token(resolved_key)}")
141+
print(f" default_model: {served_model}")
142+
print(f" api_key: {mask_token(resolved_key)}")
150143

151144
_run_async(_run)
152145

@@ -174,13 +167,11 @@ async def _discover_model(
174167
payload = await client.list_models(endpoint_url, api_key)
175168
except (DeploymentChatAuthError, BackendAPIError, BackendClientError):
176169
return fallback
177-
data = payload.get("data") if isinstance(payload, dict) else None
178-
if not isinstance(data, list):
170+
try:
171+
parsed = _ServedModelsResponse.model_validate(payload)
172+
except ValidationError:
179173
return fallback
180-
for entry in data:
181-
if isinstance(entry, dict) and isinstance(entry.get("id"), str):
182-
return str(entry["id"])
183-
return fallback
174+
return parsed.data[0].id if parsed.data else fallback
184175

185176

186177
@chat_config.command(name="show")
@@ -206,11 +197,11 @@ def show(deployment_id: UUID | None) -> None:
206197

207198
dep_ids = set(cache.deployments) | set(chat_config_store.tokens)
208199
if not dep_ids:
209-
click.echo("No chat cache entries.")
200+
print("No chat cache entries.")
210201
return
211202
for dep_id in dep_ids:
212203
_print_entry(dep_id, cache.get(dep_id), chat_config_store.get_token(dep_id))
213-
click.echo("")
204+
print()
214205

215206

216207
@chat_config.command(name="clear")
@@ -233,21 +224,25 @@ def clear(deployment_id: UUID) -> None:
233224
if removed_token:
234225
save_chat_config(chat_config_store)
235226
if removed_entry or removed_token:
236-
click.echo(f"Removed chat cache entry for deployment {deployment_id}.")
227+
print(f"Removed chat cache entry for deployment {deployment_id}.")
237228
else:
238-
click.echo(f"No chat cache entry for deployment {deployment_id}.")
229+
print(f"No chat cache entry for deployment {deployment_id}.")
239230

240231

241232
def _print_entry(
242233
deployment_id: UUID,
243234
entry: DeploymentChatCacheEntry | None,
244235
token: str | None,
245236
) -> None:
246-
click.echo(f"deployment_id : {deployment_id}")
247-
click.echo(f"endpoint_url : {entry.endpoint_url if entry else '-'}")
248-
click.echo(f"api_key : {mask_token(token)}")
249-
click.echo(f"default_model : {(entry.default_model if entry else None) or '-'}")
250-
click.echo(f"last_synced_at: {entry.last_synced_at.isoformat() if entry else '-'}")
237+
print(f"deployment_id : {deployment_id}")
238+
if entry is not None:
239+
for line in entry.format_summary():
240+
print(line)
241+
else:
242+
print("endpoint_url : -")
243+
print("default_model : -")
244+
print("last_synced_at: -")
245+
print(f"api_key : {mask_token(token)}")
251246

252247

253248
__all__ = ("chat_config",)

src/ai/backend/client/cli/v2/deployment_chat_cache.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,17 @@ class DeploymentChatCacheEntry(BaseModel):
3535
default_model: str | None = None
3636
last_synced_at: datetime
3737

38+
def format_summary(self) -> list[str]:
39+
return [
40+
f"endpoint_url : {self.endpoint_url}",
41+
f"default_model : {self.default_model or '-'}",
42+
f"last_synced_at: {self.last_synced_at.isoformat()}",
43+
]
44+
3845

3946
class DeploymentChatCache(BaseModel):
4047
"""In-memory representation of the chat cache file."""
4148

42-
schema_version: int = Field(default=CHAT_CACHE_SCHEMA_VERSION)
4349
deployments: dict[UUID, DeploymentChatCacheEntry] = Field(default_factory=dict)
4450

4551
def get(self, deployment_id: UUID) -> DeploymentChatCacheEntry | None:
@@ -93,7 +99,8 @@ def load_chat_cache(path: Path = CHAT_CACHE_FILE) -> DeploymentChatCache:
9399
def save_chat_cache(cache: DeploymentChatCache, path: Path = CHAT_CACHE_FILE) -> None:
94100
"""Atomically write the chat cache."""
95101
path.parent.mkdir(parents=True, exist_ok=True)
96-
payload = cache.model_dump_json(indent=2)
102+
body = {"schema_version": CHAT_CACHE_SCHEMA_VERSION, **cache.model_dump(mode="json")}
103+
payload = json.dumps(body, indent=2, ensure_ascii=False)
97104
fd, tmp_path_str = tempfile.mkstemp(
98105
prefix=path.name + ".",
99106
suffix=".tmp",

src/ai/backend/client/cli/v2/deployment_chat_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
class DeploymentChatConfig(BaseModel):
3030
"""Per-deployment API key registry (user-managed)."""
3131

32-
schema_version: int = Field(default=CHAT_CONFIG_SCHEMA_VERSION)
3332
tokens: dict[UUID, str] = Field(default_factory=dict)
3433

3534
def get_token(self, deployment_id: UUID) -> str | None:
@@ -79,7 +78,8 @@ def load_chat_config(path: Path = CHAT_CONFIG_FILE) -> DeploymentChatConfig:
7978
def save_chat_config(config: DeploymentChatConfig, path: Path = CHAT_CONFIG_FILE) -> None:
8079
"""Atomically write the chat config and enforce ``0600`` permissions."""
8180
path.parent.mkdir(parents=True, exist_ok=True)
82-
payload = config.model_dump_json(indent=2)
81+
body = {"schema_version": CHAT_CONFIG_SCHEMA_VERSION, **config.model_dump(mode="json")}
82+
payload = json.dumps(body, indent=2, ensure_ascii=False)
8383
fd, tmp_path_str = tempfile.mkstemp(
8484
prefix=path.name + ".",
8585
suffix=".tmp",

tests/unit/client/cli/test_deployment_chat_cache.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,20 @@ def test_load_skips_malformed_entry_payload(self, tmp_path: Path) -> None:
140140
assert list(loaded.deployments.keys()) == [good_id]
141141

142142

143+
class TestEntryFormatSummary:
144+
def test_format_summary_returns_lines(self) -> None:
145+
entry = _entry(default_model="meta/test-model")
146+
lines = entry.format_summary()
147+
assert any("endpoint_url" in line for line in lines)
148+
assert any("meta/test-model" in line for line in lines)
149+
assert any("last_synced_at" in line for line in lines)
150+
151+
def test_format_summary_dash_for_missing_default_model(self) -> None:
152+
entry = _entry(default_model=None)
153+
lines = entry.format_summary()
154+
assert any("default_model : -" in line for line in lines)
155+
156+
143157
class TestEntryMutations:
144158
def test_upsert_overwrites_existing_entry(self) -> None:
145159
cache = DeploymentChatCache()

0 commit comments

Comments
 (0)