Skip to content

Commit de5be3d

Browse files
committed
refactor(BA-5528): split chat cache token, switch entry to BaseModel, fix module placement
- DeploymentChatCacheEntry no longer carries api_key. Tokens now live in a separate top-level 'tokens' map in the same JSON cache file, surfaced via DeploymentChatCache.get_token / set_token / clear_token. remove() and chat-config clear/show treat them as one logical record.
- Convert DeploymentChatCacheEntry from a dataclass to a Pydantic BaseModel (frozen) so model_dump / model_validate replace the manual to_dict / from_dict helpers.
- Rename _resolve_endpoint to _ensure_endpoint_entry in chat.py — the function returns the full cache entry (endpoint + default_model + last_synced_at), not just the URL, so the previous name was misleading.
- Move the SDK chat client out of domains_v2/. It does not inherit BaseDomainClient nor call typed_request (it talks directly to the inference container, not the manager) so it does not belong with the REST domain clients. New location: client/v2/deployment_chat.py.
- Drop the InferenceChat* rename and keep DeploymentChat* to align with the './bai deployment chat' CLI command name.
1 parent 0ee74b8 commit de5be3d

6 files changed

Lines changed: 169 additions & 115 deletions

File tree

src/ai/backend/client/cli/v2/deployment/chat.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def chat(
7070
import json
7171
import sys
7272

73-
from ai.backend.client.v2.domains_v2.inference_chat import (
74-
InferenceChatAuthError,
75-
InferenceChatClient,
73+
from ai.backend.client.v2.deployment_chat import (
74+
DeploymentChatAuthError,
75+
DeploymentChatClient,
7676
)
7777

7878
config = load_v2_config()
@@ -87,7 +87,7 @@ def chat(
8787
extra_body: dict[str, Any] = params
8888
entry = cache.get(deployment_id)
8989

90-
async def _resolve_endpoint() -> DeploymentChatCacheEntry:
90+
async def _ensure_endpoint_entry() -> DeploymentChatCacheEntry:
9191
if entry is not None and entry.endpoint_url:
9292
return entry
9393
registry = await create_v2_registry(config)
@@ -103,7 +103,6 @@ async def _resolve_endpoint() -> DeploymentChatCacheEntry:
103103
)
104104
new_entry = DeploymentChatCacheEntry(
105105
endpoint_url=endpoint_url,
106-
api_key=entry.api_key if entry is not None else None,
107106
default_model=entry.default_model if entry is not None else None,
108107
last_synced_at=datetime.now(UTC),
109108
)
@@ -114,8 +113,8 @@ async def _resolve_endpoint() -> DeploymentChatCacheEntry:
114113
async def _run() -> None:
115114
from ai.backend.client.exceptions import BackendAPIError
116115

117-
resolved = await _resolve_endpoint()
118-
request_model = model or resolved.default_model
116+
endpoint_entry = await _ensure_endpoint_entry()
117+
request_model = model or endpoint_entry.default_model
119118
if request_model is None:
120119
raise click.ClickException(
121120
f"No --model given and no default_model cached for deployment {deployment_id}.\n"
@@ -129,23 +128,18 @@ async def _run() -> None:
129128
"model": request_model,
130129
"messages": [{"role": "user", "content": content}],
131130
}
132-
async with InferenceChatClient(
131+
api_key = cache.get_token(deployment_id)
132+
async with DeploymentChatClient(
133133
skip_ssl_verification=config.skip_ssl_verification,
134134
) as client:
135135
try:
136136
response = await client.chat_completion(
137-
resolved.endpoint_url,
138-
resolved.api_key,
137+
endpoint_entry.endpoint_url,
138+
api_key,
139139
body,
140140
)
141-
except InferenceChatAuthError as e:
142-
invalidated = DeploymentChatCacheEntry(
143-
endpoint_url=resolved.endpoint_url,
144-
api_key=None,
145-
default_model=resolved.default_model,
146-
last_synced_at=datetime.now(UTC),
147-
)
148-
cache.upsert(deployment_id, invalidated)
141+
except DeploymentChatAuthError as e:
142+
cache.clear_token(deployment_id)
149143
save_chat_cache(cache)
150144
raise click.ClickException(
151145
f"The inference endpoint rejected the configured API key for "

src/ai/backend/client/cli/v2/deployment/chat_config.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def set_(
9494
elif api_key is not None:
9595
resolved_key = api_key
9696
else:
97-
resolved_key = existing.api_key if existing is not None else None
97+
resolved_key = cache.get_token(deployment_id)
9898

9999
async def _run() -> None:
100100
registry = await create_v2_registry(config)
@@ -123,11 +123,14 @@ async def _run() -> None:
123123
deployment_id,
124124
DeploymentChatCacheEntry(
125125
endpoint_url=endpoint_url,
126-
api_key=resolved_key,
127126
default_model=served_model,
128127
last_synced_at=datetime.now(UTC),
129128
),
130129
)
130+
if resolved_key is None:
131+
cache.clear_token(deployment_id)
132+
else:
133+
cache.set_token(deployment_id, resolved_key)
131134
save_chat_cache(cache)
132135
click.echo(f"Updated chat cache entry for deployment {deployment_id}.")
133136
if served_model:
@@ -150,15 +153,15 @@ async def _discover_model(
150153
does not contain any model entries.
151154
"""
152155
from ai.backend.client.exceptions import BackendAPIError, BackendClientError
153-
from ai.backend.client.v2.domains_v2.inference_chat import (
154-
InferenceChatAuthError,
155-
InferenceChatClient,
156+
from ai.backend.client.v2.deployment_chat import (
157+
DeploymentChatAuthError,
158+
DeploymentChatClient,
156159
)
157160

158-
async with InferenceChatClient(skip_ssl_verification=skip_ssl_verification) as client:
161+
async with DeploymentChatClient(skip_ssl_verification=skip_ssl_verification) as client:
159162
try:
160163
payload = await client.list_models(endpoint_url, api_key)
161-
except (InferenceChatAuthError, BackendAPIError, BackendClientError):
164+
except (DeploymentChatAuthError, BackendAPIError, BackendClientError):
162165
return fallback
163166
data = payload.get("data") if isinstance(payload, dict) else None
164167
if not isinstance(data, list):
@@ -180,23 +183,25 @@ def show(deployment_id: UUID | None) -> None:
180183

181184
if deployment_id is not None:
182185
entry = cache.get(deployment_id)
183-
if entry is None:
186+
token = cache.get_token(deployment_id)
187+
if entry is None and token is None:
184188
raise click.ClickException(f"No chat cache entry for deployment {deployment_id}.")
185-
_print_entry(deployment_id, entry)
189+
_print_entry(deployment_id, entry, token)
186190
return
187191

188-
if not cache.entries:
192+
dep_ids = set(cache.entries) | set(cache.tokens)
193+
if not dep_ids:
189194
click.echo("No chat cache entries.")
190195
return
191-
for dep_id, entry in cache.entries.items():
192-
_print_entry(dep_id, entry)
196+
for dep_id in dep_ids:
197+
_print_entry(dep_id, cache.get(dep_id), cache.get_token(dep_id))
193198
click.echo("")
194199

195200

196201
@chat_config.command(name="clear")
197202
@click.argument("deployment_id", type=click.UUID)
198203
def clear(deployment_id: UUID) -> None:
199-
"""Remove the chat cache entry for a deployment."""
204+
"""Remove the chat cache entry and stored token for a deployment."""
200205
try:
201206
cache = load_chat_cache()
202207
except IncompatibleChatCacheError as e:
@@ -208,12 +213,16 @@ def clear(deployment_id: UUID) -> None:
208213
click.echo(f"No chat cache entry for deployment {deployment_id}.")
209214

210215

211-
def _print_entry(deployment_id: UUID, entry: DeploymentChatCacheEntry) -> None:
216+
def _print_entry(
217+
deployment_id: UUID,
218+
entry: DeploymentChatCacheEntry | None,
219+
token: str | None,
220+
) -> None:
212221
click.echo(f"deployment_id : {deployment_id}")
213-
click.echo(f"endpoint_url : {entry.endpoint_url}")
214-
click.echo(f"api_key : {mask_token(entry.api_key)}")
215-
click.echo(f"default_model : {entry.default_model or '-'}")
216-
click.echo(f"last_synced_at: {entry.last_synced_at.isoformat()}")
222+
click.echo(f"endpoint_url : {entry.endpoint_url if entry else '-'}")
223+
click.echo(f"api_key : {mask_token(token)}")
224+
click.echo(f"default_model : {(entry.default_model if entry else None) or '-'}")
225+
click.echo(f"last_synced_at: {entry.last_synced_at.isoformat() if entry else '-'}")
217226

218227

219228
__all__ = ("chat_config",)

src/ai/backend/client/cli/v2/deployment_chat_cache.py

Lines changed: 53 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Local cache for ``./bai deployment chat`` per-deployment settings.
22
3-
Stores ``endpoint_url`` (resolved from the manager) and the inference API
4-
key the user registered for each deployment so that follow-up ``chat``
5-
invocations do not need to re-query the manager nor re-prompt for the key.
6-
7-
Persisted as a single JSON file at ``~/.backend.ai/deployment_chat.json``
8-
with ``0600`` file permissions because the API key is stored in plaintext.
3+
Persists the manager-resolved ``endpoint_url`` and the served model name
4+
discovered from the inference endpoint, plus a separate map of API keys
5+
the user registered through ``./bai deployment chat-config set``. The
6+
endpoint entry is auto-managed (refetched when missing); the token is
7+
user-supplied and never auto-discovered.
8+
9+
Stored as a single JSON file at ``~/.backend.ai/deployment_chat.json``
10+
with ``0600`` permissions because the API keys are kept in plaintext.
911
"""
1012

1113
from __future__ import annotations
@@ -15,56 +17,40 @@
1517
import stat
1618
import tempfile
1719
from dataclasses import dataclass, field
18-
from datetime import UTC, datetime
20+
from datetime import datetime
1921
from pathlib import Path
2022
from typing import Any
2123
from uuid import UUID
2224

25+
from pydantic import BaseModel, ConfigDict, ValidationError
26+
2327
from ai.backend.client.cli.v2.helpers import CONFIG_DIR
2428

2529
CHAT_CACHE_FILE = CONFIG_DIR / "deployment_chat.json"
2630
CHAT_CACHE_SCHEMA_VERSION = 1
2731

2832

29-
@dataclass(frozen=True)
30-
class DeploymentChatCacheEntry:
31-
"""One deployment's chat configuration."""
33+
class DeploymentChatCacheEntry(BaseModel):
34+
"""One deployment's auto-managed endpoint metadata."""
35+
36+
model_config = ConfigDict(frozen=True)
3237

3338
endpoint_url: str
34-
api_key: str | None
35-
default_model: str | None
39+
default_model: str | None = None
3640
last_synced_at: datetime
3741

38-
def to_dict(self) -> dict[str, Any]:
39-
return {
40-
"endpoint_url": self.endpoint_url,
41-
"api_key": self.api_key,
42-
"default_model": self.default_model,
43-
"last_synced_at": self.last_synced_at.isoformat(),
44-
}
45-
46-
@classmethod
47-
def from_dict(cls, data: dict[str, Any]) -> DeploymentChatCacheEntry:
48-
synced_raw = data.get("last_synced_at")
49-
if isinstance(synced_raw, str):
50-
synced = datetime.fromisoformat(synced_raw)
51-
else:
52-
synced = datetime.now(UTC)
53-
return cls(
54-
endpoint_url=str(data["endpoint_url"]),
55-
api_key=(str(data["api_key"]) if data.get("api_key") is not None else None),
56-
default_model=(
57-
str(data["default_model"]) if data.get("default_model") is not None else None
58-
),
59-
last_synced_at=synced,
60-
)
61-
6242

6343
@dataclass
6444
class DeploymentChatCache:
65-
"""In-memory representation of the chat cache file."""
45+
"""In-memory representation of the chat cache file.
46+
47+
``entries`` is the auto-managed endpoint cache; ``tokens`` is the
48+
user-managed API-key store. They are kept in the same file under
49+
distinct top-level keys.
50+
"""
6651

6752
entries: dict[UUID, DeploymentChatCacheEntry] = field(default_factory=dict)
53+
tokens: dict[UUID, str] = field(default_factory=dict)
6854

6955
def get(self, deployment_id: UUID) -> DeploymentChatCacheEntry | None:
7056
return self.entries.get(deployment_id)
@@ -73,12 +59,26 @@ def upsert(self, deployment_id: UUID, entry: DeploymentChatCacheEntry) -> None:
7359
self.entries[deployment_id] = entry
7460

7561
def remove(self, deployment_id: UUID) -> bool:
76-
return self.entries.pop(deployment_id, None) is not None
62+
had_entry = self.entries.pop(deployment_id, None) is not None
63+
had_token = self.tokens.pop(deployment_id, None) is not None
64+
return had_entry or had_token
65+
66+
def get_token(self, deployment_id: UUID) -> str | None:
67+
return self.tokens.get(deployment_id)
68+
69+
def set_token(self, deployment_id: UUID, token: str) -> None:
70+
self.tokens[deployment_id] = token
71+
72+
def clear_token(self, deployment_id: UUID) -> bool:
73+
return self.tokens.pop(deployment_id, None) is not None
7774

7875
def to_dict(self) -> dict[str, Any]:
7976
return {
8077
"schema_version": CHAT_CACHE_SCHEMA_VERSION,
81-
"deployments": {str(dep_id): entry.to_dict() for dep_id, entry in self.entries.items()},
78+
"deployments": {
79+
str(dep_id): entry.model_dump(mode="json") for dep_id, entry in self.entries.items()
80+
},
81+
"tokens": {str(dep_id): token for dep_id, token in self.tokens.items()},
8282
}
8383

8484

@@ -109,8 +109,8 @@ def load_chat_cache(path: Path = CHAT_CACHE_FILE) -> DeploymentChatCache:
109109
f"deployment_chat.json schema version {schema} is newer than supported "
110110
f"{CHAT_CACHE_SCHEMA_VERSION}; please upgrade the client."
111111
)
112-
deployments_raw = raw.get("deployments") or {}
113112
entries: dict[UUID, DeploymentChatCacheEntry] = {}
113+
deployments_raw = raw.get("deployments") or {}
114114
if isinstance(deployments_raw, dict):
115115
for key, value in deployments_raw.items():
116116
try:
@@ -120,10 +120,20 @@ def load_chat_cache(path: Path = CHAT_CACHE_FILE) -> DeploymentChatCache:
120120
if not isinstance(value, dict):
121121
continue
122122
try:
123-
entries[dep_id] = DeploymentChatCacheEntry.from_dict(value)
124-
except (KeyError, ValueError, TypeError):
123+
entries[dep_id] = DeploymentChatCacheEntry.model_validate(value)
124+
except ValidationError:
125+
continue
126+
tokens: dict[UUID, str] = {}
127+
tokens_raw = raw.get("tokens") or {}
128+
if isinstance(tokens_raw, dict):
129+
for key, value in tokens_raw.items():
130+
try:
131+
dep_id = UUID(str(key))
132+
except ValueError:
125133
continue
126-
return DeploymentChatCache(entries=entries)
134+
if isinstance(value, str):
135+
tokens[dep_id] = value
136+
return DeploymentChatCache(entries=entries, tokens=tokens)
127137

128138

129139
def save_chat_cache(cache: DeploymentChatCache, path: Path = CHAT_CACHE_FILE) -> None:

src/ai/backend/client/v2/domains_v2/inference_chat.py renamed to src/ai/backend/client/v2/deployment_chat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
DEFAULT_MODELS_PATH = "/v1/models"
2525

2626

27-
class InferenceChatAuthError(BackendAPIError):
27+
class DeploymentChatAuthError(BackendAPIError):
2828
"""Raised when the inference endpoint rejects the configured API key."""
2929

3030

31-
class InferenceChatClient:
31+
class DeploymentChatClient:
3232
"""Direct HTTP client for OpenAI-compatible inference endpoints."""
3333

3434
_session: aiohttp.ClientSession
@@ -145,7 +145,7 @@ def _raise_for_status(resp: aiohttp.ClientResponse, payload: object) -> None:
145145
return
146146
data = payload if isinstance(payload, dict) else {"detail": payload}
147147
if resp.status in (401, 403):
148-
raise InferenceChatAuthError(resp.status, resp.reason or "Unauthorized", data)
148+
raise DeploymentChatAuthError(resp.status, resp.reason or "Unauthorized", data)
149149
raise BackendAPIError(resp.status, resp.reason or "HTTP error", data)
150150

151151
@staticmethod

0 commit comments

Comments
 (0)