Skip to content

Commit 93d4578

Browse files
committed
refactor(BA-5528): split chat tokens into deployment_chat_config; both stores as BaseModel
- Move user-managed API keys out of DeploymentChatCache and into a new DeploymentChatConfig store backed by ~/.backend.ai/deployment_chat_config.json. The cache is now strictly endpoint metadata that the manager owns; the config is strictly the user-supplied tokens.
- Make both DeploymentChatCache and DeploymentChatConfig Pydantic BaseModel subclasses. Drop the manual to_dict helpers in favor of model_dump_json for serialization. Per-entry validation in load functions still skips malformed records individually.
- Update chat / chat-config / clear / show to operate on both stores in tandem, preserving the single ./bai deployment chat-config UX.
1 parent de5be3d commit 93d4578

6 files changed

Lines changed: 316 additions & 164 deletions

File tree

src/ai/backend/client/cli/v2/deployment/chat.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
load_chat_cache,
1919
save_chat_cache,
2020
)
21+
from ai.backend.client.cli.v2.deployment_chat_config import (
22+
IncompatibleChatConfigError,
23+
load_chat_config,
24+
save_chat_config,
25+
)
2126
from ai.backend.client.cli.v2.helpers import create_v2_registry, load_v2_config
2227

2328

@@ -75,12 +80,16 @@ def chat(
7580
DeploymentChatClient,
7681
)
7782

78-
config = load_v2_config()
83+
connection = load_v2_config()
7984

8085
try:
8186
cache = load_chat_cache()
8287
except IncompatibleChatCacheError as e:
8388
raise click.ClickException(str(e)) from e
89+
try:
90+
chat_config = load_chat_config()
91+
except IncompatibleChatConfigError as e:
92+
raise click.ClickException(str(e)) from e
8493

8594
if not isinstance(params, dict):
8695
raise click.ClickException("--params must be a JSON object.")
@@ -90,7 +99,7 @@ def chat(
9099
async def _ensure_endpoint_entry() -> DeploymentChatCacheEntry:
91100
if entry is not None and entry.endpoint_url:
92101
return entry
93-
registry = await create_v2_registry(config)
102+
registry = await create_v2_registry(connection)
94103
try:
95104
deployment = await registry.deployment.get(deployment_id)
96105
finally:
@@ -128,9 +137,9 @@ async def _run() -> None:
128137
"model": request_model,
129138
"messages": [{"role": "user", "content": content}],
130139
}
131-
api_key = cache.get_token(deployment_id)
140+
api_key = chat_config.get_token(deployment_id)
132141
async with DeploymentChatClient(
133-
skip_ssl_verification=config.skip_ssl_verification,
142+
skip_ssl_verification=connection.skip_ssl_verification,
134143
) as client:
135144
try:
136145
response = await client.chat_completion(
@@ -139,8 +148,8 @@ async def _run() -> None:
139148
body,
140149
)
141150
except DeploymentChatAuthError as e:
142-
cache.clear_token(deployment_id)
143-
save_chat_cache(cache)
151+
chat_config.clear_token(deployment_id)
152+
save_chat_config(chat_config)
144153
raise click.ClickException(
145154
f"The inference endpoint rejected the configured API key for "
146155
f"deployment {deployment_id}. The cached key has been cleared.\n"

src/ai/backend/client/cli/v2/deployment/chat_config.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@
1414
DeploymentChatCacheEntry,
1515
IncompatibleChatCacheError,
1616
load_chat_cache,
17-
mask_token,
1817
save_chat_cache,
1918
)
19+
from ai.backend.client.cli.v2.deployment_chat_config import (
20+
IncompatibleChatConfigError,
21+
load_chat_config,
22+
mask_token,
23+
save_chat_config,
24+
)
2025
from ai.backend.client.cli.v2.helpers import create_v2_registry, load_v2_config
2126

2227

@@ -81,23 +86,27 @@ def set_(
8186
if api_key and no_token:
8287
raise click.ClickException("--token and --no-token are mutually exclusive.")
8388

84-
config = load_v2_config()
89+
connection = load_v2_config()
8590
try:
8691
cache = load_chat_cache()
8792
except IncompatibleChatCacheError as e:
8893
raise click.ClickException(str(e)) from e
94+
try:
95+
chat_config_store = load_chat_config()
96+
except IncompatibleChatConfigError as e:
97+
raise click.ClickException(str(e)) from e
8998

90-
existing = cache.get(deployment_id)
99+
existing_entry = cache.get(deployment_id)
91100
resolved_key: str | None
92101
if no_token:
93102
resolved_key = None
94103
elif api_key is not None:
95104
resolved_key = api_key
96105
else:
97-
resolved_key = cache.get_token(deployment_id)
106+
resolved_key = chat_config_store.get_token(deployment_id)
98107

99108
async def _run() -> None:
100-
registry = await create_v2_registry(config)
109+
registry = await create_v2_registry(connection)
101110
try:
102111
deployment = await registry.deployment.get(deployment_id)
103112
finally:
@@ -115,8 +124,8 @@ async def _run() -> None:
115124
served_model = await _discover_model(
116125
endpoint_url,
117126
resolved_key,
118-
config.skip_ssl_verification,
119-
existing.default_model if existing is not None else None,
127+
connection.skip_ssl_verification,
128+
existing_entry.default_model if existing_entry is not None else None,
120129
)
121130

122131
cache.upsert(
@@ -127,11 +136,13 @@ async def _run() -> None:
127136
last_synced_at=datetime.now(UTC),
128137
),
129138
)
139+
save_chat_cache(cache)
130140
if resolved_key is None:
131-
cache.clear_token(deployment_id)
141+
chat_config_store.clear_token(deployment_id)
132142
else:
133-
cache.set_token(deployment_id, resolved_key)
134-
save_chat_cache(cache)
143+
chat_config_store.set_token(deployment_id, resolved_key)
144+
save_chat_config(chat_config_store)
145+
135146
click.echo(f"Updated chat cache entry for deployment {deployment_id}.")
136147
if served_model:
137148
click.echo(f" default_model: {served_model}")
@@ -180,21 +191,25 @@ def show(deployment_id: UUID | None) -> None:
180191
cache = load_chat_cache()
181192
except IncompatibleChatCacheError as e:
182193
raise click.ClickException(str(e)) from e
194+
try:
195+
chat_config_store = load_chat_config()
196+
except IncompatibleChatConfigError as e:
197+
raise click.ClickException(str(e)) from e
183198

184199
if deployment_id is not None:
185200
entry = cache.get(deployment_id)
186-
token = cache.get_token(deployment_id)
201+
token = chat_config_store.get_token(deployment_id)
187202
if entry is None and token is None:
188203
raise click.ClickException(f"No chat cache entry for deployment {deployment_id}.")
189204
_print_entry(deployment_id, entry, token)
190205
return
191206

192-
dep_ids = set(cache.entries) | set(cache.tokens)
207+
dep_ids = set(cache.deployments) | set(chat_config_store.tokens)
193208
if not dep_ids:
194209
click.echo("No chat cache entries.")
195210
return
196211
for dep_id in dep_ids:
197-
_print_entry(dep_id, cache.get(dep_id), cache.get_token(dep_id))
212+
_print_entry(dep_id, cache.get(dep_id), chat_config_store.get_token(dep_id))
198213
click.echo("")
199214

200215

@@ -206,8 +221,18 @@ def clear(deployment_id: UUID) -> None:
206221
cache = load_chat_cache()
207222
except IncompatibleChatCacheError as e:
208223
raise click.ClickException(str(e)) from e
209-
if cache.remove(deployment_id):
224+
try:
225+
chat_config_store = load_chat_config()
226+
except IncompatibleChatConfigError as e:
227+
raise click.ClickException(str(e)) from e
228+
229+
removed_entry = cache.remove(deployment_id)
230+
removed_token = chat_config_store.clear_token(deployment_id)
231+
if removed_entry:
210232
save_chat_cache(cache)
233+
if removed_token:
234+
save_chat_config(chat_config_store)
235+
if removed_entry or removed_token:
211236
click.echo(f"Removed chat cache entry for deployment {deployment_id}.")
212237
else:
213238
click.echo(f"No chat cache entry for deployment {deployment_id}.")
Lines changed: 20 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
"""Local cache for ``./bai deployment chat`` per-deployment settings.
1+
"""Local cache for ``./bai deployment chat`` per-deployment endpoint metadata.
22
3-
Persists the manager-resolved ``endpoint_url`` and the served model name
4-
discovered from the inference endpoint, plus a separate map of API keys
5-
the user registered through ``./bai deployment chat-config set``. The
6-
endpoint entry is auto-managed (refetched when missing); the token is
7-
user-supplied and never auto-discovered.
3+
Stores the manager-resolved ``endpoint_url`` and the served model name
4+
discovered from the inference endpoint. Auto-managed: refetched on cache
5+
miss, never user-edited. The user-supplied API key lives in a separate
6+
file managed by ``deployment_chat_config``.
87
9-
Stored as a single JSON file at ``~/.backend.ai/deployment_chat.json``
10-
with ``0600`` permissions because the API keys are kept in plaintext.
8+
Persisted as a JSON file at ``~/.backend.ai/deployment_chat.json``.
119
"""
1210

1311
from __future__ import annotations
@@ -16,13 +14,11 @@
1614
import os
1715
import stat
1816
import tempfile
19-
from dataclasses import dataclass, field
2017
from datetime import datetime
2118
from pathlib import Path
22-
from typing import Any
2319
from uuid import UUID
2420

25-
from pydantic import BaseModel, ConfigDict, ValidationError
21+
from pydantic import BaseModel, ConfigDict, Field, ValidationError
2622

2723
from ai.backend.client.cli.v2.helpers import CONFIG_DIR
2824

@@ -40,60 +36,28 @@ class DeploymentChatCacheEntry(BaseModel):
4036
last_synced_at: datetime
4137

4238

43-
@dataclass
44-
class DeploymentChatCache:
45-
"""In-memory representation of the chat cache file.
39+
class DeploymentChatCache(BaseModel):
40+
"""In-memory representation of the chat cache file."""
4641

47-
``entries`` is the auto-managed endpoint cache; ``tokens`` is the
48-
user-managed API-key store. They are kept in the same file under
49-
distinct top-level keys.
50-
"""
51-
52-
entries: dict[UUID, DeploymentChatCacheEntry] = field(default_factory=dict)
53-
tokens: dict[UUID, str] = field(default_factory=dict)
42+
schema_version: int = Field(default=CHAT_CACHE_SCHEMA_VERSION)
43+
deployments: dict[UUID, DeploymentChatCacheEntry] = Field(default_factory=dict)
5444

5545
def get(self, deployment_id: UUID) -> DeploymentChatCacheEntry | None:
56-
return self.entries.get(deployment_id)
46+
return self.deployments.get(deployment_id)
5747

5848
def upsert(self, deployment_id: UUID, entry: DeploymentChatCacheEntry) -> None:
59-
self.entries[deployment_id] = entry
49+
self.deployments[deployment_id] = entry
6050

6151
def remove(self, deployment_id: UUID) -> bool:
62-
had_entry = self.entries.pop(deployment_id, None) is not None
63-
had_token = self.tokens.pop(deployment_id, None) is not None
64-
return had_entry or had_token
65-
66-
def get_token(self, deployment_id: UUID) -> str | None:
67-
return self.tokens.get(deployment_id)
68-
69-
def set_token(self, deployment_id: UUID, token: str) -> None:
70-
self.tokens[deployment_id] = token
71-
72-
def clear_token(self, deployment_id: UUID) -> bool:
73-
return self.tokens.pop(deployment_id, None) is not None
74-
75-
def to_dict(self) -> dict[str, Any]:
76-
return {
77-
"schema_version": CHAT_CACHE_SCHEMA_VERSION,
78-
"deployments": {
79-
str(dep_id): entry.model_dump(mode="json") for dep_id, entry in self.entries.items()
80-
},
81-
"tokens": {str(dep_id): token for dep_id, token in self.tokens.items()},
82-
}
52+
return self.deployments.pop(deployment_id, None) is not None
8353

8454

8555
class IncompatibleChatCacheError(Exception):
8656
"""Raised when the on-disk cache file uses a newer schema than this build."""
8757

8858

8959
def load_chat_cache(path: Path = CHAT_CACHE_FILE) -> DeploymentChatCache:
90-
"""Load the chat cache; return an empty cache when the file is absent or unreadable.
91-
92-
A corrupted JSON file or unreadable file is treated as an empty cache —
93-
individual malformed entries are skipped rather than aborting the whole
94-
load. A schema version newer than this build raises
95-
:class:`IncompatibleChatCacheError` so the caller can warn the user.
96-
"""
60+
"""Load the chat cache; return an empty cache when the file is absent or unreadable."""
9761
if not path.exists():
9862
return DeploymentChatCache()
9963
try:
@@ -109,7 +73,7 @@ def load_chat_cache(path: Path = CHAT_CACHE_FILE) -> DeploymentChatCache:
10973
f"deployment_chat.json schema version {schema} is newer than supported "
11074
f"{CHAT_CACHE_SCHEMA_VERSION}; please upgrade the client."
11175
)
112-
entries: dict[UUID, DeploymentChatCacheEntry] = {}
76+
deployments: dict[UUID, DeploymentChatCacheEntry] = {}
11377
deployments_raw = raw.get("deployments") or {}
11478
if isinstance(deployments_raw, dict):
11579
for key, value in deployments_raw.items():
@@ -120,26 +84,16 @@ def load_chat_cache(path: Path = CHAT_CACHE_FILE) -> DeploymentChatCache:
12084
if not isinstance(value, dict):
12185
continue
12286
try:
123-
entries[dep_id] = DeploymentChatCacheEntry.model_validate(value)
87+
deployments[dep_id] = DeploymentChatCacheEntry.model_validate(value)
12488
except ValidationError:
12589
continue
126-
tokens: dict[UUID, str] = {}
127-
tokens_raw = raw.get("tokens") or {}
128-
if isinstance(tokens_raw, dict):
129-
for key, value in tokens_raw.items():
130-
try:
131-
dep_id = UUID(str(key))
132-
except ValueError:
133-
continue
134-
if isinstance(value, str):
135-
tokens[dep_id] = value
136-
return DeploymentChatCache(entries=entries, tokens=tokens)
90+
return DeploymentChatCache(deployments=deployments)
13791

13892

13993
def save_chat_cache(cache: DeploymentChatCache, path: Path = CHAT_CACHE_FILE) -> None:
140-
"""Atomically write the chat cache and enforce ``0600`` permissions."""
94+
"""Atomically write the chat cache."""
14195
path.parent.mkdir(parents=True, exist_ok=True)
142-
payload = json.dumps(cache.to_dict(), indent=2, ensure_ascii=False)
96+
payload = cache.model_dump_json(indent=2)
14397
fd, tmp_path_str = tempfile.mkstemp(
14498
prefix=path.name + ".",
14599
suffix=".tmp",
@@ -155,12 +109,3 @@ def save_chat_cache(cache: DeploymentChatCache, path: Path = CHAT_CACHE_FILE) ->
155109
if tmp_path.exists():
156110
tmp_path.unlink(missing_ok=True)
157111
raise
158-
159-
160-
def mask_token(token: str | None) -> str:
161-
"""Render a token as ``sk-***...***xxxx`` for diagnostic display."""
162-
if token is None:
163-
return "<unset>"
164-
if len(token) <= 8:
165-
return "***"
166-
return f"{token[:3]}***...***{token[-4:]}"

0 commit comments

Comments
 (0)