diff --git a/changes/11344.feature.md b/changes/11344.feature.md new file mode 100644 index 00000000000..9a97f621428 --- /dev/null +++ b/changes/11344.feature.md @@ -0,0 +1 @@ +Add `./bai deployment chat` for one-shot OpenAI-compatible chat against deployed inference services. diff --git a/src/ai/backend/client/cli/v2/deployment/__init__.py b/src/ai/backend/client/cli/v2/deployment/__init__.py index 0bb7aae2207..bf7c9a8575b 100644 --- a/src/ai/backend/client/cli/v2/deployment/__init__.py +++ b/src/ai/backend/client/cli/v2/deployment/__init__.py @@ -1,5 +1,6 @@ from .access_token import access_token from .auto_scaling_rule import auto_scaling_rule +from .chat import chat, chat_cache, chat_config, chat_history from .commands import deployment as deployment from .options import options from .policy import policy @@ -15,5 +16,9 @@ deployment.add_command(access_token) deployment.add_command(auto_scaling_rule) deployment.add_command(options) +deployment.add_command(chat) +deployment.add_command(chat_config) +deployment.add_command(chat_cache) +deployment.add_command(chat_history) __all__ = ("deployment",) diff --git a/src/ai/backend/client/cli/v2/deployment/chat/__init__.py b/src/ai/backend/client/cli/v2/deployment/chat/__init__.py new file mode 100644 index 00000000000..bd909854e30 --- /dev/null +++ b/src/ai/backend/client/cli/v2/deployment/chat/__init__.py @@ -0,0 +1,17 @@ +"""``./bai deployment chat``, ``chat-config``, ``chat-cache``, and ``chat-history`` CLI commands. + +Submodules: +- :mod:`commands` — Click command/group definitions. +- :mod:`types` — Pydantic models for the on-disk cache, config, and history, + including the ``.load()``/``.save()`` classmethods that wire them to + ``~/.backend.ai/deployment_chat/*.json``. +- :mod:`utils` — file paths and shared JSON I/O helpers. +- :mod:`formatter` — display helpers (``mask_token``, ``DeploymentChatFormatter``). 
+ +OpenAI-compat wire DTOs live in :mod:`ai.backend.common.dto.clients.openai_compat` +so they can be reused by any backend.ai component. +""" + +from .commands import chat, chat_cache, chat_config, chat_history + +__all__ = ("chat", "chat_cache", "chat_config", "chat_history") diff --git a/src/ai/backend/client/cli/v2/deployment/chat/commands.py b/src/ai/backend/client/cli/v2/deployment/chat/commands.py new file mode 100644 index 00000000000..6cc5bfb6daf --- /dev/null +++ b/src/ai/backend/client/cli/v2/deployment/chat/commands.py @@ -0,0 +1,417 @@ +"""User-facing CLI: ``./bai deployment chat``, ``chat-config``, ``chat-cache``, ``chat-history``.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable, Coroutine +from datetime import UTC, datetime +from typing import Any + +import click + +from ai.backend.cli.params import JSONParamType +from ai.backend.client.cli.v2.deployment.chat.formatter import ( + DeploymentChatFormatter, + mask_token, +) +from ai.backend.client.cli.v2.deployment.chat.types import ( + DEFAULT_CHAT_HISTORY_LIMIT, + DeploymentChatCache, + DeploymentChatCacheEntry, + DeploymentChatConfig, + DeploymentChatHistory, +) +from ai.backend.client.cli.v2.helpers import create_v2_registry, load_v2_config +from ai.backend.common.dto.clients.openai_compat import ChatCompletionRequest +from ai.backend.common.identifier.deployment import DeploymentID + + +def _run_async(coro_fn: Callable[[], Coroutine[Any, Any, None]]) -> None: + from ai.backend.client.exceptions import BackendAPIError + + try: + asyncio.run(coro_fn()) + except BackendAPIError as e: + data: Any = e.args[2] if len(e.args) > 2 else {} + title = data.get("title", "") if isinstance(data, dict) else "" + msg = data.get("msg", "") if isinstance(data, dict) else "" + status = e.args[0] if e.args else "?" 
@click.command(name="chat")
@click.argument("deployment_id", type=click.UUID)
@click.argument("message", type=str)
@click.option(
    "--model",
    default=None,
    type=str,
    help=(
        "Model name to send. Resolution order: this flag, then the user-set "
        "config.json model, then the auto-cached cache.json default_model, "
        "then GET /v1/models on the deployment."
    ),
)
@click.option(
    "--params",
    default="{}",
    type=JSONParamType(),
    help=(
        "Extra request-body fields as a JSON object. "
        "Forwarded to the inference endpoint as-is "
        '(e.g. \'{"temperature": 0.7, "max_tokens": 256}\'). '
        "The 'model' and 'messages' fields are always overridden by --model and MESSAGE."
    ),
)
@click.option(
    "--history-limit",
    default=DEFAULT_CHAT_HISTORY_LIMIT,
    type=click.IntRange(min=0),
    show_default=True,
    help=(
        "Maximum number of past messages from this deployment's persisted "
        "history to replay as context. Set to 0 to skip context for this "
        "turn (the round is still recorded; use `chat-history clear` to "
        "wipe the persisted transcript)."
    ),
)
def chat(
    deployment_id: DeploymentID,
    message: str,
    model: str | None,
    params: Any,
    history_limit: int,
) -> None:
    """Send a one-shot chat completion request to a deployed model.

    Targets OpenAI-compatible Chat Completions endpoints
    (vLLM / SGLang / NIM / TGI in messages-api mode / custom containers
    that follow the same contract). Sampling parameters such as
    temperature and top_p differ between runtime variants — pass them
    through ``--params``.
    """
    # NOTE: the docstring above doubles as the ``--help`` text, so maintainer
    # notes live in comments instead.
    #
    # Flow:
    #   1. Resolve the endpoint URL from cache.json, or from the manager when
    #      the cache entry is missing or expired, and re-cache it.
    #   2. Resolve the model: --model > config.json > cache.json > /v1/models.
    #   3. Replay up to ``history_limit`` persisted messages plus MESSAGE.
    #   4. Print the response JSON and persist the round in history.json.
    #
    # Raises click.ClickException for invalid --params, a deployment without
    # an endpoint URL, an endpoint that advertises no models, and 401/403
    # responses from the inference endpoint.
    from ai.backend.client.v2.config import ClientConfig
    from ai.backend.client.v2.deployment_chat import DeploymentChatClient
    from ai.backend.client.v2.exceptions import DeploymentAuthError

    # Fail fast on bad input before reading any on-disk state.
    if not isinstance(params, dict):
        raise click.ClickException("--params must be a JSON object.")
    extra_body: dict[str, Any] = params

    connection_config = load_v2_config()

    cache = DeploymentChatCache.load()
    # Named ``user_config`` so it does not shadow the module-level
    # ``chat_config`` click group.
    user_config = DeploymentChatConfig.load()
    history = DeploymentChatHistory.load()

    async def _run() -> None:
        existing = cache.get(deployment_id)
        if existing is not None and not existing.is_expired(now=datetime.now(UTC)):
            endpoint_entry = existing
        else:
            registry = await create_v2_registry(connection_config)
            try:
                deployment = await registry.deployment.get(deployment_id)
            finally:
                await registry.close()
            endpoint_url = deployment.network_access.endpoint_url
            if not endpoint_url:
                raise click.ClickException(
                    f"Deployment {deployment_id} has no endpoint_url yet "
                    "(it may still be provisioning). Wait until the deployment is READY."
                )
            # Keep a previously derived default model across a TTL refresh so
            # an expired entry does not force another /v1/models round-trip.
            endpoint_entry = DeploymentChatCacheEntry(
                endpoint_url=endpoint_url,
                default_model=existing.default_model if existing is not None else None,
                last_synced_at=datetime.now(UTC),
            )
            cache.set(deployment_id, endpoint_entry)
            cache.save()

        token = user_config.get_token(deployment_id)
        # ``endpoint`` is required on ClientConfig but unused by AppProxyClient
        # (deployment URLs are passed per-request); pass through the manager
        # endpoint so the rest of the connection knobs (TLS, timeouts) match.
        client_config = ClientConfig(
            endpoint=connection_config.endpoint,
            endpoint_type=connection_config.endpoint_type,
            api_version=connection_config.api_version,
            skip_ssl_verification=connection_config.skip_ssl_verification,
        )
        async with DeploymentChatClient(client_config) as client:
            try:
                # Resolution: --model > config.model (user-set) >
                # cache.default_model (auto) > GET /v1/models (auto, cached).
                request_model = (
                    model or user_config.get_model(deployment_id) or endpoint_entry.default_model
                )
                if request_model is None:
                    # No explicit --model, no user-set config, no cached
                    # default — ask the OpenAI-compat endpoint itself which
                    # models it serves and adopt the first one as the
                    # cached default (matches webui ChatCard.tsx behaviour).
                    models_response = await client.list_models(endpoint_entry.endpoint_url, token)
                    if not models_response.data:
                        raise click.ClickException(
                            f"Deployment {deployment_id} did not advertise any models "
                            f"on /v1/models. Set one explicitly with:\n"
                            f"  ./bai deployment chat-config set {deployment_id} "
                            f"--model <MODEL_NAME>"
                        )
                    request_model = models_response.data[0].id
                    endpoint_entry = DeploymentChatCacheEntry(
                        endpoint_url=endpoint_entry.endpoint_url,
                        default_model=request_model,
                        last_synced_at=endpoint_entry.last_synced_at,
                    )
                    cache.set(deployment_id, endpoint_entry)
                    cache.save()

                past_messages = history.slice(deployment_id, history_limit)
                request_messages: list[dict[str, str]] = [
                    *({"role": past.role, "content": past.content} for past in past_messages),
                    {"role": "user", "content": message},
                ]
                # ``model`` and ``messages`` come last so they always override
                # same-named keys passed through --params.
                request = ChatCompletionRequest.model_validate({
                    **extra_body,
                    "model": request_model,
                    "messages": request_messages,
                })
                body = request.model_dump(mode="json")
                response = await client.chat_completion(endpoint_entry.endpoint_url, token, body)
            except DeploymentAuthError as e:
                # 401/403 from /v1/models or /v1/chat/completions. Other
                # BackendAPIErrors fall through to ``_run_async`` which formats
                # the manager-style status/title/msg payload generically.
                if token is None:
                    # Nothing was sent and nothing was cleared, so do not
                    # claim a stored token was rejected.
                    raise click.ClickException(
                        f"The inference endpoint for deployment {deployment_id} "
                        f"rejected the request and no token is configured. "
                        f"If the deployment requires one, register it with:\n"
                        f"  ./bai deployment chat-config set {deployment_id} --token <TOKEN>"
                    ) from e
                # Invalidate the stored token so the next ``chat`` call
                # surfaces the same hint instead of silently re-sending a
                # stale key.
                if user_config.clear_token(deployment_id):
                    user_config.save()
                raise click.ClickException(
                    f"The inference endpoint rejected the configured token for "
                    f"deployment {deployment_id}. The stored token has been cleared; "
                    f"re-register with:\n"
                    f"  ./bai deployment chat-config set {deployment_id} --token <TOKEN>"
                ) from e
        # Only persist when both sides of the round are present, so the file
        # never carries half-conversations that would skew future context.
        assistant_message = response.assistant_message
        if assistant_message is not None:
            now = datetime.now(UTC)
            history.append(deployment_id, "user", message, created_at=now)
            history.append(deployment_id, "assistant", assistant_message, created_at=now)
            history.save()
        print(response.model_dump_json(indent=2))

    _run_async(_run)
@chat_config.command(name="show")
@click.argument("deployment_id", type=click.UUID)
def show(deployment_id: DeploymentID) -> None:
    """Print the user-managed chat config entry for a deployment (tokens are masked).

    Only the user-managed fields (``token``, ``model``) are shown; the
    auto-managed cache (``endpoint_url``, ``default_model``,
    ``last_synced_at``) is treated as internal CLI state and not part of
    this view.
    """
    # The docstring is the ``--help`` text and is kept as-is.
    stored_config = DeploymentChatConfig.load()
    entry = stored_config.get(deployment_id)
    if entry is not None:
        # Rendering (including token masking) is delegated to the formatter.
        DeploymentChatFormatter.print_config(deployment_id, entry)
        return
    # Nothing registered for this deployment: report it as a CLI error.
    raise click.ClickException(f"No chat config for deployment {deployment_id}.")
@chat_cache.command(name="show")
@click.argument("deployment_id", type=click.UUID)
def cache_show(deployment_id: DeploymentID) -> None:
    """Print the auto-managed chat cache entry for a deployment."""
    # The docstring is the ``--help`` text and is kept as-is.
    stored_cache = DeploymentChatCache.load()
    cached_entry = stored_cache.get(deployment_id)
    if cached_entry is not None:
        DeploymentChatFormatter.print_cache(deployment_id, cached_entry)
        return
    # No entry cached for this deployment: report it as a CLI error.
    raise click.ClickException(f"No chat cache for deployment {deployment_id}.")
+ """ + cache = DeploymentChatCache.load() + if cache.delete(deployment_id): + cache.save() + print(f"Removed cache entry for deployment {deployment_id}.") + else: + print(f"No cache entry for deployment {deployment_id}.") + + +# --------------------------------------------------------------------------- +# chat-history +# --------------------------------------------------------------------------- + + +@click.group(name="chat-history") +def chat_history() -> None: + """Manage per-deployment chat transcripts. + + The ``chat`` command auto-records each user/assistant round into + ``~/.backend.ai/deployment_chat/history.json`` so subsequent calls + can replay recent turns as context. Use this group to inspect or + wipe what has been persisted. + """ + + +@chat_history.command(name="show") +@click.argument("deployment_id", type=click.UUID) +@click.option( + "--limit", + default=None, + type=click.IntRange(min=1), + help="Print only the most recent N messages (default: all persisted).", +) +def history_show(deployment_id: DeploymentID, limit: int | None) -> None: + """Print the persisted transcript for a deployment.""" + history = DeploymentChatHistory.load() + messages = history.get(deployment_id) + if messages is None: + print(f"No chat history for deployment {deployment_id}.") + return + if not messages: + print(f"Chat history for deployment {deployment_id} is empty.") + return + visible = messages if limit is None else messages[-limit:] + print(f"deployment_id : {deployment_id}") + print(f"messages : {len(messages)} persisted (showing {len(visible)})") + for message in visible: + print(f" [{message.created_at.isoformat()}] {message.role}: {message.content}") + + +@chat_history.command(name="clear") +@click.argument("deployment_id", type=click.UUID) +def history_clear(deployment_id: DeploymentID) -> None: + """Drop the persisted transcript for a deployment. + + The next ``chat`` call starts a fresh context. Cache and config + entries are unaffected. 
def mask_token(token: str | None) -> str:
    """Return a display-safe stand-in for a stored token.

    Every non-``None`` token maps to the same fixed placeholder, so the
    rendered value reveals nothing about the token's length or content.
    ``None`` (no token stored) renders as an empty string.
    """
    return "" if token is None else "********"
+ """ + + @classmethod + def print_config( + cls, + deployment_id: DeploymentID, + entry: DeploymentChatConfigEntry, + ) -> None: + print(f"deployment_id : {deployment_id}") + print(f"model : {entry.model or '-'}") + print(f"token : {mask_token(entry.token)}") + + @classmethod + def print_cache( + cls, + deployment_id: DeploymentID, + entry: DeploymentChatCacheEntry, + ) -> None: + print(f"deployment_id : {deployment_id}") + print(f"endpoint_url : {entry.endpoint_url}") + print(f"default_model : {entry.default_model or '-'}") + print(f"last_synced_at: {entry.last_synced_at.isoformat()}") diff --git a/src/ai/backend/client/cli/v2/deployment/chat/types.py b/src/ai/backend/client/cli/v2/deployment/chat/types.py new file mode 100644 index 00000000000..163fdb089e3 --- /dev/null +++ b/src/ai/backend/client/cli/v2/deployment/chat/types.py @@ -0,0 +1,255 @@ +"""Type definitions for ``./bai deployment chat`` storage.""" + +from __future__ import annotations + +import sys +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Annotated, Self + +from pydantic import BaseModel, Field, ValidationError + +from ai.backend.client.cli.v2.deployment.chat.utils import ( + CHAT_CACHE_FILE, + CHAT_CONFIG_FILE, + CHAT_HISTORY_FILE, + read_json_file, + write_json_file, +) +from ai.backend.common.identifier.deployment import DeploymentID + +CACHE_ENTRY_TTL = timedelta(hours=24) +"""Endpoint cache entries older than this are treated as a cache miss.""" + +DEFAULT_CHAT_HISTORY_LIMIT = 10 +"""Default number of past messages forwarded as context on each ``chat`` call. + +Mirrors the typical 5-turn rolling window of OpenAI-compatible chat UIs. +Override per-call with ``--history-limit``; setting it to 0 disables context. +""" + +MAX_PERSISTED_HISTORY_MESSAGES = 100 +"""Hard cap on messages kept in ``history.json`` per deployment. + +The file holds plain text; capping it keeps disk usage bounded even when the +user never runs ``chat-history clear``. 
class DeploymentChatCache(BaseModel):
    """Model of the chat cache file: one auto-managed entry per deployment."""

    deployments: dict[DeploymentID, DeploymentChatCacheEntry] = Field(default_factory=dict)

    def get(self, deployment_id: DeploymentID) -> DeploymentChatCacheEntry | None:
        """Return the cached entry for ``deployment_id``, or None when absent."""
        return self.deployments.get(deployment_id)

    def set(self, deployment_id: DeploymentID, entry: DeploymentChatCacheEntry) -> None:
        """Store ``entry`` for ``deployment_id``, replacing any previous one."""
        self.deployments[deployment_id] = entry

    def delete(self, deployment_id: DeploymentID) -> bool:
        """Drop the entry for ``deployment_id``; True when something was removed."""
        removed = self.deployments.pop(deployment_id, None)
        return removed is not None

    @classmethod
    def load(cls) -> Self:
        """Build the cache from ``CHAT_CACHE_FILE``.

        Falls back to an empty cache when the file yields no data, or when
        its content fails validation (a warning is printed to stderr).
        """
        payload = read_json_file(CHAT_CACHE_FILE)
        if payload is not None:
            try:
                return cls.model_validate(payload)
            except ValidationError:
                print(
                    f"WARNING: {CHAT_CACHE_FILE} is in an invalid format and was ignored.",
                    file=sys.stderr,
                )
        return cls()

    def save(self) -> None:
        """Write the cache to ``CHAT_CACHE_FILE`` as indented plain JSON.

        Matches the existing CLI credential storage convention; see
        ``client/cli/v2/config_cmd.py``.
        """
        serialized = self.model_dump_json(indent=2)
        write_json_file(CHAT_CACHE_FILE, serialized)
class DeploymentChatConfig(BaseModel):
    """User-managed registry keyed by deployment: bearer token and chosen model name."""

    deployments: defaultdict[
        DeploymentID,
        Annotated[DeploymentChatConfigEntry, Field(default_factory=DeploymentChatConfigEntry)],
    ] = Field(default_factory=lambda: defaultdict(DeploymentChatConfigEntry))

    def get(self, deployment_id: DeploymentID) -> DeploymentChatConfigEntry | None:
        """Return the entry for ``deployment_id``, or None when absent."""
        # ``.get`` rather than indexing, so a lookup never creates an entry.
        return self.deployments.get(deployment_id)

    def get_token(self, deployment_id: DeploymentID) -> str | None:
        """Return the stored token, or None when no entry or no token exists."""
        found = self.deployments.get(deployment_id)
        if found is None:
            return None
        return found.token

    def get_model(self, deployment_id: DeploymentID) -> str | None:
        """Return the user-chosen model, or None when no entry or no model exists."""
        found = self.deployments.get(deployment_id)
        if found is None:
            return None
        return found.model

    def set_token(self, deployment_id: DeploymentID, token: str) -> None:
        """Store ``token`` for ``deployment_id``."""
        # Indexing relies on the mapping's default factory for a first-time entry.
        record = self.deployments[deployment_id]
        record.token = token

    def set_model(self, deployment_id: DeploymentID, model: str) -> None:
        """Store ``model`` for ``deployment_id``."""
        record = self.deployments[deployment_id]
        record.model = model

    def clear_token(self, deployment_id: DeploymentID) -> bool:
        """Unset the token for ``deployment_id``; True when a token was cleared.

        The whole entry is removed once neither token nor model remains.
        """
        found = self.deployments.get(deployment_id)
        if found is not None and found.token is not None:
            found.token = None
            if found.is_empty():
                del self.deployments[deployment_id]
            return True
        return False

    def clear_model(self, deployment_id: DeploymentID) -> bool:
        """Unset the model for ``deployment_id``; True when a model was cleared.

        The whole entry is removed once neither token nor model remains.
        """
        found = self.deployments.get(deployment_id)
        if found is not None and found.model is not None:
            found.model = None
            if found.is_empty():
                del self.deployments[deployment_id]
            return True
        return False

    def delete(self, deployment_id: DeploymentID) -> bool:
        """Drop the entry for ``deployment_id``; True when something was removed."""
        removed = self.deployments.pop(deployment_id, None)
        return removed is not None

    @classmethod
    def load(cls) -> Self:
        """Build the config from ``CHAT_CONFIG_FILE``.

        Falls back to an empty config when the file yields no data, or when
        its content fails validation (a warning is printed to stderr).
        """
        payload = read_json_file(CHAT_CONFIG_FILE)
        if payload is not None:
            try:
                return cls.model_validate(payload)
            except ValidationError:
                print(
                    f"WARNING: {CHAT_CONFIG_FILE} is in an invalid format and was ignored. "
                    "Re-register tokens with `./bai deployment chat-config set`.",
                    file=sys.stderr,
                )
        return cls()

    def save(self) -> None:
        """Write the config to ``CHAT_CONFIG_FILE`` as indented plain JSON.

        Matches the existing CLI credential storage convention; see
        ``client/cli/v2/config_cmd.py``.
        """
        serialized = self.model_dump_json(indent=2)
        write_json_file(CHAT_CONFIG_FILE, serialized)
+ """ + + deployments: dict[DeploymentID, list[ChatMessage]] = Field(default_factory=dict) + + def get(self, deployment_id: DeploymentID) -> list[ChatMessage] | None: + return self.deployments.get(deployment_id) + + def slice(self, deployment_id: DeploymentID, limit: int) -> list[ChatMessage]: + """Return the last ``limit`` turns of the transcript for replay as request context.""" + if limit <= 0: + return [] + messages = self.deployments.get(deployment_id) + if not messages: + return [] + return messages[-limit:] + + def append( + self, + deployment_id: DeploymentID, + role: str, + content: str, + *, + created_at: datetime, + max_persisted: int = MAX_PERSISTED_HISTORY_MESSAGES, + ) -> None: + """Append one turn and FIFO-truncate to keep the file bounded.""" + messages = self.deployments.setdefault(deployment_id, []) + messages.append(ChatMessage(role=role, content=content, created_at=created_at)) + overflow = len(messages) - max_persisted + if overflow > 0: + del messages[:overflow] + + def clear(self, deployment_id: DeploymentID) -> bool: + return self.deployments.pop(deployment_id, None) is not None + + @classmethod + def load(cls) -> Self: + """Load the chat history; return an empty history when the file is absent or unreadable.""" + raw = read_json_file(CHAT_HISTORY_FILE) + if raw is None: + return cls() + try: + return cls.model_validate(raw) + except ValidationError: + print( + f"WARNING: {CHAT_HISTORY_FILE} is in an invalid format and was ignored.", + file=sys.stderr, + ) + return cls() + + def save(self) -> None: + """Persist the history as a plain JSON file (matches existing CLI credential + storage convention; see ``client/cli/v2/config_cmd.py``). 
def write_json_file(path: Path, text: str) -> None:
    """Store ``text`` at ``path`` as UTF-8, creating missing parent directories.

    An existing file is overwritten in place. The write is deliberately
    plain (no atomic rename, no permission changes).
    """
    directory = path.parent
    directory.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(text)
class BackendAIAppProxyClient:
    """HTTP client base for direct-to-deployment endpoints fronted by Backend.AI's app-proxy.

    Unlike :class:`BackendAIAuthClient` (which signs requests with HMAC against
    the Backend.AI manager API), this client targets the runtime's own HTTP
    surface (vLLM / SGLang / NIM / TGI / custom) and uses an optional
    ``Authorization: Bearer <token>`` header. The deployment endpoint URL is
    supplied per-request, not via :attr:`ClientConfig.endpoint`.

    Subclasses add the contract-specific request methods (e.g. chat-completions,
    /generate, etc.).

    Error contract of :meth:`_request`:

    - 401/403 always raise :class:`DeploymentAuthError`, whatever the body
      looks like (JSON object, other JSON, HTML, empty).
    - any other status >= 400 raises :class:`BackendAPIError`.
    - a connection failure, or a successful response whose body is not a
      JSON object, raises :class:`BackendClientError`.
    """

    _config: ClientConfig
    _session: aiohttp.ClientSession

    def __init__(self, config: ClientConfig) -> None:
        self._config = config
        self._session = _create_aiohttp_session(config)

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc: BaseException | None,
        _tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP session; safe to call more than once."""
        if not self._session.closed:
            await self._session.close()

    async def _request(
        self,
        method: str,
        endpoint_url: str,
        path: str,
        token: str | None,
        *,
        body: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Send one request to ``endpoint_url`` + ``path`` and return the JSON object.

        ``token`` is sent as a Bearer credential when truthy; ``body`` is sent
        as JSON when given. See the class docstring for the exceptions raised.
        """
        target = self._build_url(endpoint_url, path)
        headers: dict[str, str] = {}
        if body is not None:
            headers["Content-Type"] = "application/json"
        if token:
            headers["Authorization"] = f"Bearer {token}"
        try:
            async with self._session.request(method, target, headers=headers, json=body) as resp:
                payload = await self._parse_response(resp)
                # All status -> exception mapping happens here, so the error
                # type never depends on the shape of the response body.
                self._raise_for_status(resp, payload)
                return payload
        except aiohttp.ClientConnectionError as e:
            raise BackendClientError(f"failed to reach deployment endpoint: {e!r}") from e

    @staticmethod
    async def _parse_response(resp: aiohttp.ClientResponse) -> dict[str, Any]:
        """Decode the response body into a dict.

        For error statuses (>= 400) this never raises: a body that is not a
        JSON object is wrapped as ``{"detail": <raw text>}`` so the caller's
        :meth:`_raise_for_status` picks the exception type from the status
        code alone. Without this, a 401/403 carrying an HTML page or an empty
        body would not surface as :class:`DeploymentAuthError`.

        Raises:
            BackendClientError: a non-error response whose body is not a
                JSON object.
        """
        # Backend.AI's app-proxy fronts every deployment endpoint, and on
        # 5xx it can emit HTML / plain-text bodies (e.g. cloud LB error
        # pages) instead of JSON. Read text up front so the raw body is
        # available either as the JSON-parse input or as error context.
        raw = await resp.text()
        is_error = resp.status >= 400
        try:
            payload = json.loads(raw) if raw else None
        except json.JSONDecodeError as e:
            if is_error:
                return {"detail": raw}
            raise BackendClientError(
                f"deployment endpoint returned non-JSON response (status={resp.status}): {raw!r}"
            ) from e
        if isinstance(payload, dict):
            return payload
        if is_error:
            # Empty body, or valid JSON that is not an object (list, string, ...).
            return {"detail": raw}
        raise BackendClientError(
            f"deployment endpoint returned non-object payload "
            f"(type={type(payload).__name__}, status={resp.status}): {payload!r}"
        )

    @staticmethod
    def _build_url(endpoint_url: str, path: str) -> str:
        """Join ``path`` onto ``endpoint_url`` without doubling it.

        When the endpoint URL's path already ends with ``path`` it is used
        as-is; otherwise ``path`` is appended to the endpoint's path.
        """
        base = URL(endpoint_url)
        target_path = path if path.startswith("/") else "/" + path
        base_path = base.path.rstrip("/")
        if base_path.endswith(target_path):
            return str(base.with_path(base_path))
        return str(base.with_path(f"{base_path}{target_path}"))

    @staticmethod
    def _raise_for_status(resp: aiohttp.ClientResponse, payload: dict[str, Any]) -> None:
        """Raise the exception matching an error status; return None below 400."""
        if resp.status < 400:
            return
        if resp.status in (401, 403):
            raise DeploymentAuthError(resp.status, resp.reason or "Unauthorized", payload)
        raise BackendAPIError(resp.status, resp.reason or "HTTP error", payload)
live in +:mod:`ai.backend.common.dto.clients.openai_compat` so other components can +reuse them. Session lifecycle, JSON parsing, URL normalization, and +401/403 → auth-error mapping live on :class:`BackendAIAppProxyClient` in +:mod:`ai.backend.client.v2.base_client`. +""" + +from __future__ import annotations + +from typing import Any + +from ai.backend.client.v2.base_client import BackendAIAppProxyClient +from ai.backend.common.dto.clients.openai_compat import ( + ChatCompletionResponse, + ListModelsResponse, +) + +_OPENAI_COMPATIBLE_CHAT_PATH = "/v1/chat/completions" +_OPENAI_COMPATIBLE_MODELS_PATH = "/v1/models" + + +class DeploymentChatClient(BackendAIAppProxyClient): + """OpenAI Chat Completions client for direct-to-deployment inference traffic. + + Sends ``POST /v1/chat/completions`` with an OpenAI-shaped + ``{model, messages, ...}`` JSON body. Compatible runtimes: vLLM, + SGLang, NVIDIA NIM, and TGI in Messages API mode. Vanilla TGI + (``/generate``) and arbitrary custom containers need a different + :class:`BackendAIAppProxyClient` subclass. + """ + + async def chat_completion( + self, + endpoint_url: str, + token: str | None, + body: dict[str, Any], + ) -> ChatCompletionResponse: + payload = await self._request( + "POST", endpoint_url, _OPENAI_COMPATIBLE_CHAT_PATH, token, body=body + ) + return ChatCompletionResponse.model_validate(payload) + + async def list_models( + self, + endpoint_url: str, + token: str | None, + ) -> ListModelsResponse: + """Fetch ``GET /v1/models`` — the OpenAI-compat model listing. + + Used to auto-derive a default model name when the caller did not + pass ``--model`` and no cached default is known. 
+ """ + payload = await self._request("GET", endpoint_url, _OPENAI_COMPATIBLE_MODELS_PATH, token) + return ListModelsResponse.model_validate(payload) diff --git a/src/ai/backend/client/v2/exceptions.py b/src/ai/backend/client/v2/exceptions.py index 105031f77b6..0099bf50192 100644 --- a/src/ai/backend/client/v2/exceptions.py +++ b/src/ai/backend/client/v2/exceptions.py @@ -41,6 +41,11 @@ class ServerError(BackendAPIError): } +class DeploymentAuthError(BackendAPIError): + """Raised when a deployment's inference endpoint (fronted by app-proxy) + rejects the configured token with HTTP 401/403.""" + + class WebSocketError(BackendClientError): """Error during WebSocket connection or communication.""" diff --git a/src/ai/backend/common/dto/clients/openai_compat/__init__.py b/src/ai/backend/common/dto/clients/openai_compat/__init__.py new file mode 100644 index 00000000000..2a84bf804ad --- /dev/null +++ b/src/ai/backend/common/dto/clients/openai_compat/__init__.py @@ -0,0 +1,26 @@ +"""DTOs for OpenAI-compatible Chat Completions endpoints. + +Wire-format Pydantic models for the OpenAI HTTP contract that vLLM, SGLang, +NVIDIA NIM, and TGI (in messages-api mode) implement. Kept under +``common/dto/clients/`` so any backend.ai component (CLI today, manager or +agent tomorrow) can consume the same types when talking to a deployed model. 
+""" + +from .request import ChatCompletionMessage, ChatCompletionRequest +from .response import ( + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseMessage, + ListModelsResponse, + ModelEntry, +) + +__all__ = ( + "ChatCompletionMessage", + "ChatCompletionRequest", + "ChatCompletionResponse", + "ChatCompletionResponseChoice", + "ChatCompletionResponseMessage", + "ListModelsResponse", + "ModelEntry", +) diff --git a/src/ai/backend/common/dto/clients/openai_compat/request.py b/src/ai/backend/common/dto/clients/openai_compat/request.py new file mode 100644 index 00000000000..0e86f202bc7 --- /dev/null +++ b/src/ai/backend/common/dto/clients/openai_compat/request.py @@ -0,0 +1,27 @@ +"""Request DTOs for OpenAI-compatible inference endpoints (vLLM, SGLang, NIM, TGI).""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + + +class ChatCompletionMessage(BaseModel): + """One message inside a chat-completions request.""" + + role: str + content: str + + +class ChatCompletionRequest(BaseModel): + """Body for ``POST /v1/chat/completions`` (OpenAI-compatible). + + Extra fields are forwarded so callers can pass runtime-variant-specific + knobs (e.g. ``temperature``, ``top_p``, vLLM/NIM extensions) through + ``./bai deployment chat --params`` without the CLI having to enumerate + them. 
+ """ + + model_config = ConfigDict(extra="allow") + + model: str + messages: list[ChatCompletionMessage] diff --git a/src/ai/backend/common/dto/clients/openai_compat/response.py b/src/ai/backend/common/dto/clients/openai_compat/response.py new file mode 100644 index 00000000000..2ad39f4944b --- /dev/null +++ b/src/ai/backend/common/dto/clients/openai_compat/response.py @@ -0,0 +1,86 @@ +"""Response DTOs for OpenAI-compatible inference endpoints (vLLM, SGLang, NIM, TGI).""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + + +class _OpenAICompatModel(BaseModel): + """Base for OpenAI-compat response DTOs. + + Runtimes (vLLM, SGLang, NIM, TGI) ship runtime-specific extras + (``usage``, ``system_fingerprint``, ``tool_calls``, + ``reasoning_content``, ``prompt_logprobs``, ``owned_by``, ...). + ``extra="allow"`` keeps them on the model so ``model_dump_json`` + round-trips faithfully back to the CLI's stdout pretty-print. + """ + + model_config = ConfigDict(extra="allow") + + +class ModelEntry(_OpenAICompatModel): + """One entry in an OpenAI-compat ``GET /v1/models`` response.""" + + id: str + object: str = "model" + + +class ListModelsResponse(_OpenAICompatModel): + """Body of ``GET /v1/models`` on an OpenAI-compat endpoint.""" + + object: str = "list" + data: list[ModelEntry] + + +class ChatCompletionResponseMessage(_OpenAICompatModel): + """The ``message`` payload inside one OpenAI-compat choice. + + Only ``content`` is consumed by the CLI (for chat-history persistence); + runtime-specific fields like ``tool_calls`` or ``reasoning_content`` + (DeepSeek-R1, Qwen-QwQ) pass through to the JSON pretty-printed output + via the inherited ``extra="allow"``. + """ + + role: str | None = None + content: str | None = None + + +class ChatCompletionResponseChoice(_OpenAICompatModel): + """One entry in ``choices[]`` on a non-streaming chat-completion response. 
+ + Streaming responses use ``delta`` instead of ``message``; this model + intentionally requires ``message`` so a streaming chunk that slips + through (the SDK never sets ``stream=true`` itself) fails loudly via + ``ValidationError`` rather than corrupting persisted history. + """ + + message: ChatCompletionResponseMessage + + +class ChatCompletionResponse(_OpenAICompatModel): + """Body of ``POST /v1/chat/completions`` (OpenAI-compatible). + + Only the path used by chat-history bookkeeping + (``choices[0].message.content``) is typed here. Top-level extras + (``id``, ``object``, ``created``, ``model``, ``usage``, + ``system_fingerprint``, runtime-specific fields) ride through via + the inherited ``extra="allow"`` so they survive the round-trip back + to the user's stdout when the CLI pretty-prints the response. + """ + + choices: list[ChatCompletionResponseChoice] + + @property + def assistant_message(self) -> str | None: + """Text emitted by the model in the first choice, if any. + + Returns ``None`` when the response advertised no choices or when + the assistant emitted only a tool-call (``message.content`` is + ``null`` in that case). The CLI uses this to gate chat-history + persistence: a half-recorded round (user logged but assistant + missing) would skew future context, so we skip the save in that + case. 
+ """ + if not self.choices: + return None + return self.choices[0].message.content diff --git a/tests/component/prometheus_query_preset/test_prometheus_query_preset_preview.py b/tests/component/prometheus_query_preset/test_prometheus_query_preset_preview.py index 03e1a125e1b..84976faa510 100644 --- a/tests/component/prometheus_query_preset/test_prometheus_query_preset_preview.py +++ b/tests/component/prometheus_query_preset/test_prometheus_query_preset_preview.py @@ -37,7 +37,7 @@ async def test_returns_prometheus_response( query_template: str, result_type: str, ) -> None: - prometheus_client_mock.query_instant = AsyncMock( + prometheus_client_mock.preview_query_template = AsyncMock( return_value=PrometheusResponse( status="success", data=PrometheusQueryData(result_type=result_type, result=[]), @@ -50,14 +50,14 @@ async def test_returns_prometheus_response( assert result.status == "success" assert result.data.result_type == result_type - prometheus_client_mock.query_instant.assert_called_once() + prometheus_client_mock.preview_query_template.assert_called_once() async def test_propagates_prometheus_error( self, admin_v2_registry: V2ClientRegistry, prometheus_client_mock: MagicMock, ) -> None: - prometheus_client_mock.query_instant = AsyncMock( + prometheus_client_mock.preview_query_template = AsyncMock( side_effect=FailedToGetMetric('parse error: unexpected "}" (status=400, path=query)'), ) diff --git a/tests/unit/client/cli/test_deployment_chat_formatter.py b/tests/unit/client/cli/test_deployment_chat_formatter.py new file mode 100644 index 00000000000..b0209842cc5 --- /dev/null +++ b/tests/unit/client/cli/test_deployment_chat_formatter.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from uuid import uuid4 + +import pytest + +from ai.backend.client.cli.v2.deployment.chat.formatter import ( + DeploymentChatFormatter, + mask_token, +) +from ai.backend.client.cli.v2.deployment.chat.types import ( + 
DeploymentChatCacheEntry, + DeploymentChatConfigEntry, +) + + +class TestPrintConfig: + def test_prints_token_masked_and_model_when_set( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + DeploymentChatFormatter.print_config( + uuid4(), + DeploymentChatConfigEntry(token="sk-secret", model="llama-3-8b"), + ) + out = capsys.readouterr().out + assert "model : llama-3-8b" in out + assert "token : ********" in out + assert "sk-secret" not in out + + def test_prints_dashes_when_unset(self, capsys: pytest.CaptureFixture[str]) -> None: + DeploymentChatFormatter.print_config( + uuid4(), + DeploymentChatConfigEntry(token=None, model=None), + ) + out = capsys.readouterr().out + assert "model : -" in out + assert "token : " in out + + +class TestPrintCache: + def test_prints_endpoint_default_model_and_last_synced( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + DeploymentChatFormatter.print_cache( + uuid4(), + DeploymentChatCacheEntry( + endpoint_url="https://infer.example.test/api", + default_model="meta/test-model", + last_synced_at=datetime(2026, 4, 27, 12, 0, tzinfo=UTC), + ), + ) + out = capsys.readouterr().out + assert "endpoint_url : https://infer.example.test/api" in out + assert "default_model : meta/test-model" in out + assert "last_synced_at: 2026-04-27T12:00:00+00:00" in out + + def test_prints_dash_for_missing_default_model( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + DeploymentChatFormatter.print_cache( + uuid4(), + DeploymentChatCacheEntry( + endpoint_url="https://infer.example.test/api", + default_model=None, + last_synced_at=datetime(2026, 4, 27, 12, 0, tzinfo=UTC), + ), + ) + out = capsys.readouterr().out + assert "default_model : -" in out + + +class TestMaskToken: + def test_mask_long_token_returns_fixed_placeholder(self) -> None: + # Length-independent placeholder: never leak prefix, suffix, or length. 
+ assert mask_token("sk-abcdefghijklmnopqrstuvwxyz") == "********" + + def test_mask_short_token_returns_fixed_placeholder(self) -> None: + assert mask_token("short") == "********" + + def test_mask_none(self) -> None: + assert mask_token(None) == "" diff --git a/tests/unit/client/cli/test_deployment_chat_types.py b/tests/unit/client/cli/test_deployment_chat_types.py new file mode 100644 index 00000000000..c844788b551 --- /dev/null +++ b/tests/unit/client/cli/test_deployment_chat_types.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +import pytest + +from ai.backend.client.cli.v2.deployment.chat.types import ( + DeploymentChatCache, + DeploymentChatCacheEntry, + DeploymentChatConfig, + DeploymentChatConfigEntry, + DeploymentChatHistory, +) + + +@pytest.fixture +def cache() -> DeploymentChatCache: + return DeploymentChatCache() + + +@pytest.fixture +def chat_config() -> DeploymentChatConfig: + return DeploymentChatConfig() + + +@pytest.fixture +def chat_history() -> DeploymentChatHistory: + return DeploymentChatHistory() + + +@pytest.fixture +def history_now() -> datetime: + return datetime(2026, 4, 27, 12, 0, tzinfo=UTC) + + +@pytest.fixture +def cache_entry() -> DeploymentChatCacheEntry: + return DeploymentChatCacheEntry( + endpoint_url="https://infer.example.test/api", + default_model=None, + last_synced_at=datetime(2026, 4, 27, 12, 0, tzinfo=UTC), + ) + + +@pytest.fixture +def deployment_id() -> UUID: + return uuid4() + + +def _entry_with_model(default_model: str | None) -> DeploymentChatCacheEntry: + return DeploymentChatCacheEntry( + endpoint_url="https://infer.example.test/api", + default_model=default_model, + last_synced_at=datetime(2026, 4, 27, 12, 0, tzinfo=UTC), + ) + + +class TestCacheMutations: + def test_set_overwrites_existing_entry( + self, cache: DeploymentChatCache, deployment_id: UUID + ) -> None: + cache.set(deployment_id, _entry_with_model("m1")) + 
cache.set(deployment_id, _entry_with_model("m2")) + stored = cache.get(deployment_id) + assert stored is not None + assert stored.default_model == "m2" + + def test_delete_returns_true_when_present( + self, + cache: DeploymentChatCache, + cache_entry: DeploymentChatCacheEntry, + deployment_id: UUID, + ) -> None: + cache.set(deployment_id, cache_entry) + assert cache.delete(deployment_id) is True + assert cache.get(deployment_id) is None + + def test_delete_returns_false_when_absent( + self, cache: DeploymentChatCache, deployment_id: UUID + ) -> None: + assert cache.delete(deployment_id) is False + + +class TestConfigTokenStore: + def test_set_and_get_token( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_token(deployment_id, "sk-abc") + assert chat_config.get_token(deployment_id) == "sk-abc" + + def test_set_overwrites_existing_token( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_token(deployment_id, "sk-old") + chat_config.set_token(deployment_id, "sk-new") + assert chat_config.get_token(deployment_id) == "sk-new" + + def test_clear_token_returns_true_when_present( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_token(deployment_id, "sk-x") + assert chat_config.clear_token(deployment_id) is True + assert chat_config.get_token(deployment_id) is None + + def test_clear_token_returns_false_when_absent( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + assert chat_config.clear_token(deployment_id) is False + + def test_clear_token_keeps_entry_when_model_remains( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_token(deployment_id, "sk-x") + chat_config.set_model(deployment_id, "llama-3-8b") + assert chat_config.clear_token(deployment_id) is True + # Model side of the entry survives token removal. 
+ entry = chat_config.get(deployment_id) + assert entry is not None + assert entry.token is None + assert entry.model == "llama-3-8b" + + +class TestConfigModelStore: + def test_set_and_get_model( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_model(deployment_id, "llama-3-8b") + assert chat_config.get_model(deployment_id) == "llama-3-8b" + + def test_set_overwrites_existing_model( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_model(deployment_id, "old-model") + chat_config.set_model(deployment_id, "new-model") + assert chat_config.get_model(deployment_id) == "new-model" + + def test_clear_model_returns_true_when_present( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_model(deployment_id, "llama-3-8b") + assert chat_config.clear_model(deployment_id) is True + assert chat_config.get_model(deployment_id) is None + + def test_clear_model_returns_false_when_absent( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + assert chat_config.clear_model(deployment_id) is False + + def test_token_and_model_share_one_entry( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_token(deployment_id, "sk-x") + chat_config.set_model(deployment_id, "llama-3-8b") + entry = chat_config.get(deployment_id) + assert entry == DeploymentChatConfigEntry(token="sk-x", model="llama-3-8b") + + +class TestConfigDelete: + def test_delete_removes_whole_entry( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + chat_config.set_token(deployment_id, "sk-x") + chat_config.set_model(deployment_id, "llama-3-8b") + assert chat_config.delete(deployment_id) is True + assert chat_config.get(deployment_id) is None + + def test_delete_returns_false_when_absent( + self, chat_config: DeploymentChatConfig, deployment_id: UUID + ) -> None: + assert chat_config.delete(deployment_id) 
is False + + +class TestHistoryAppendSlice: + def test_slice_returns_empty_when_no_entry( + self, chat_history: DeploymentChatHistory, deployment_id: UUID + ) -> None: + assert chat_history.slice(deployment_id, 5) == [] + + def test_slice_zero_limit_returns_empty_even_when_populated( + self, + chat_history: DeploymentChatHistory, + deployment_id: UUID, + history_now: datetime, + ) -> None: + chat_history.append(deployment_id, "user", "hi", created_at=history_now) + chat_history.append(deployment_id, "assistant", "hello", created_at=history_now) + # Limit 0 lets callers send a turn without context but still record it. + assert chat_history.slice(deployment_id, 0) == [] + + def test_slice_returns_in_insertion_order( + self, + chat_history: DeploymentChatHistory, + deployment_id: UUID, + history_now: datetime, + ) -> None: + for index in range(6): + chat_history.append( + deployment_id, + "user" if index % 2 == 0 else "assistant", + f"msg-{index}", + created_at=history_now, + ) + recent = chat_history.slice(deployment_id, 3) + assert [m.content for m in recent] == ["msg-3", "msg-4", "msg-5"] + + def test_slice_caps_to_available_length( + self, + chat_history: DeploymentChatHistory, + deployment_id: UUID, + history_now: datetime, + ) -> None: + chat_history.append(deployment_id, "user", "only", created_at=history_now) + # Asking for more than exists must not error out and must not pad. + assert [m.content for m in chat_history.slice(deployment_id, 10)] == ["only"] + + +class TestHistoryTruncation: + def test_append_drops_oldest_when_max_persisted_exceeded( + self, + chat_history: DeploymentChatHistory, + deployment_id: UUID, + history_now: datetime, + ) -> None: + # FIFO truncation: oldest is dropped first so the most recent + # context survives across long sessions. 
+ for index in range(5): + chat_history.append( + deployment_id, + "user", + f"msg-{index}", + created_at=history_now, + max_persisted=3, + ) + stored = chat_history.get(deployment_id) + assert stored is not None + assert [m.content for m in stored] == ["msg-2", "msg-3", "msg-4"] + + +class TestHistoryClear: + def test_clear_returns_true_when_present( + self, + chat_history: DeploymentChatHistory, + deployment_id: UUID, + history_now: datetime, + ) -> None: + chat_history.append(deployment_id, "user", "hi", created_at=history_now) + assert chat_history.clear(deployment_id) is True + assert chat_history.get(deployment_id) is None + + def test_clear_returns_false_when_absent( + self, + chat_history: DeploymentChatHistory, + deployment_id: UUID, + ) -> None: + assert chat_history.clear(deployment_id) is False diff --git a/tests/unit/client/cli/test_deployment_chat_utils.py b/tests/unit/client/cli/test_deployment_chat_utils.py new file mode 100644 index 00000000000..0eda6f891bd --- /dev/null +++ b/tests/unit/client/cli/test_deployment_chat_utils.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import json +from datetime import UTC, datetime +from pathlib import Path +from uuid import uuid4 + +import pytest + +from ai.backend.client.cli.v2.deployment.chat import types as chat_types +from ai.backend.client.cli.v2.deployment.chat import utils as chat_utils +from ai.backend.client.cli.v2.deployment.chat.types import ( + DeploymentChatCache, + DeploymentChatCacheEntry, + DeploymentChatConfig, + DeploymentChatHistory, +) + + +@pytest.fixture +def cache_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + path = tmp_path / "deployment_chat" / "cache.json" + path.parent.mkdir(parents=True, exist_ok=True) + # Both ``utils`` (where the path constant lives) and ``types`` (which + # imported it at module load time) must see the redirected path. 
+ monkeypatch.setattr(chat_utils, "CHAT_CACHE_FILE", path) + monkeypatch.setattr(chat_types, "CHAT_CACHE_FILE", path) + return path + + +@pytest.fixture +def config_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + path = tmp_path / "deployment_chat" / "config.json" + path.parent.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(chat_utils, "CHAT_CONFIG_FILE", path) + monkeypatch.setattr(chat_types, "CHAT_CONFIG_FILE", path) + return path + + +@pytest.fixture +def history_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + path = tmp_path / "deployment_chat" / "history.json" + path.parent.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(chat_utils, "CHAT_HISTORY_FILE", path) + monkeypatch.setattr(chat_types, "CHAT_HISTORY_FILE", path) + return path + + +def _make_entry( + *, + endpoint: str = "https://infer.example.test/api", + default_model: str | None = None, +) -> DeploymentChatCacheEntry: + return DeploymentChatCacheEntry( + endpoint_url=endpoint, + default_model=default_model, + last_synced_at=datetime(2026, 4, 27, 12, 0, tzinfo=UTC), + ) + + +class TestCacheLoadSaveRoundTrip: + def test_load_returns_empty_when_file_missing(self, cache_path: Path) -> None: + assert DeploymentChatCache.load().deployments == {} + + def test_save_then_load_preserves_entry(self, cache_path: Path) -> None: + cache = DeploymentChatCache() + dep_id = uuid4() + original = _make_entry(default_model="gpt-test") + cache.set(dep_id, original) + cache.save() + + loaded = DeploymentChatCache.load() + restored = loaded.deployments[dep_id] + assert restored.endpoint_url == original.endpoint_url + assert restored.default_model == original.default_model + assert restored.last_synced_at == original.last_synced_at + + +class TestConfigLoadSaveRoundTrip: + def test_load_returns_empty_when_file_missing(self, config_path: Path) -> None: + assert DeploymentChatConfig.load().deployments == {} + + def test_save_then_load_preserves_token_and_model(self, 
config_path: Path) -> None: + cfg = DeploymentChatConfig() + dep_id = uuid4() + cfg.set_token(dep_id, "sk-secret-token-1234") + cfg.set_model(dep_id, "llama-3-8b-instruct") + cfg.save() + + loaded = DeploymentChatConfig.load() + assert loaded.get_token(dep_id) == "sk-secret-token-1234" + assert loaded.get_model(dep_id) == "llama-3-8b-instruct" + + +class TestCacheLoaderResilience: + def test_load_returns_empty_on_corrupt_json(self, cache_path: Path) -> None: + cache_path.write_text("not-json{", encoding="utf-8") + assert DeploymentChatCache.load().deployments == {} + + def test_load_returns_empty_when_top_level_not_object(self, cache_path: Path) -> None: + cache_path.write_text("[]", encoding="utf-8") + assert DeploymentChatCache.load().deployments == {} + + def test_load_returns_empty_on_invalid_uuid_key(self, cache_path: Path) -> None: + cache_path.write_text( + json.dumps({ + "deployments": { + "not-a-uuid": { + "endpoint_url": "https://x.example", + "default_model": None, + "last_synced_at": "2026-04-27T12:00:00+00:00", + }, + }, + }), + encoding="utf-8", + ) + assert DeploymentChatCache.load().deployments == {} + + def test_load_returns_empty_on_malformed_entry_payload(self, cache_path: Path) -> None: + cache_path.write_text( + json.dumps({ + "deployments": { + "12345678-1234-5678-1234-567812345678": {"default_model": "m"}, + }, + }), + encoding="utf-8", + ) + assert DeploymentChatCache.load().deployments == {} + + +class TestConfigLoaderResilience: + def test_load_returns_empty_on_corrupt_json(self, config_path: Path) -> None: + config_path.write_text("not-json{", encoding="utf-8") + assert DeploymentChatConfig.load().deployments == {} + + def test_load_returns_empty_on_invalid_uuid_key(self, config_path: Path) -> None: + config_path.write_text( + json.dumps({ + "deployments": {"not-a-uuid": {"token": "sk-x", "model": None}}, + }), + encoding="utf-8", + ) + assert DeploymentChatConfig.load().deployments == {} + + +class TestHistoryLoadSaveRoundTrip: + def 
test_load_returns_empty_when_file_missing(self, history_path: Path) -> None: + assert DeploymentChatHistory.load().deployments == {} + + def test_save_then_load_preserves_messages(self, history_path: Path) -> None: + history = DeploymentChatHistory() + dep_id = uuid4() + now = datetime(2026, 4, 27, 12, 0, tzinfo=UTC) + history.append(dep_id, "user", "hello", created_at=now) + history.append(dep_id, "assistant", "world", created_at=now) + history.save() + + loaded = DeploymentChatHistory.load() + messages = loaded.get(dep_id) + assert messages is not None + assert [(m.role, m.content) for m in messages] == [ + ("user", "hello"), + ("assistant", "world"), + ] + assert messages[0].created_at == now + + +class TestHistoryLoaderResilience: + def test_load_returns_empty_on_corrupt_json(self, history_path: Path) -> None: + history_path.write_text("not-json{", encoding="utf-8") + assert DeploymentChatHistory.load().deployments == {} + + def test_load_returns_empty_when_top_level_not_object(self, history_path: Path) -> None: + history_path.write_text("[]", encoding="utf-8") + assert DeploymentChatHistory.load().deployments == {} + + def test_load_returns_empty_on_invalid_uuid_key(self, history_path: Path) -> None: + history_path.write_text( + json.dumps({ + "deployments": { + "not-a-uuid": [ + { + "role": "user", + "content": "hi", + "created_at": "2026-04-27T12:00:00+00:00", + }, + ], + }, + }), + encoding="utf-8", + ) + assert DeploymentChatHistory.load().deployments == {} + + def test_load_returns_empty_on_malformed_message_payload(self, history_path: Path) -> None: + history_path.write_text( + json.dumps({ + "deployments": { + "12345678-1234-5678-1234-567812345678": [{"role": "user"}], + }, + }), + encoding="utf-8", + ) + assert DeploymentChatHistory.load().deployments == {} diff --git a/tests/unit/client/v2/test_deployment_chat_client.py b/tests/unit/client/v2/test_deployment_chat_client.py new file mode 100644 index 00000000000..2c0e15ece5c --- /dev/null +++ 
# File: tests/unit/client/v2/test_deployment_chat_client.py
"""Unit tests for ``DeploymentChatClient.chat_completion`` and ``ChatCompletionResponse``.

All HTTP traffic is intercepted with ``aioresponses``, so no real endpoint
is contacted by these tests.
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any

import pytest
from aioresponses import aioresponses
from pydantic import ValidationError
from yarl import URL

from ai.backend.client.exceptions import BackendAPIError, BackendClientError
from ai.backend.client.v2.config import ClientConfig
from ai.backend.client.v2.deployment_chat import DeploymentChatClient
from ai.backend.client.v2.exceptions import DeploymentAuthError
from ai.backend.common.dto.clients.openai_compat import ChatCompletionResponse

BASE_URL = "http://infer.test.local"
CHAT_URL = f"{BASE_URL}/v1/chat/completions"


@pytest.fixture
async def chat_client() -> AsyncIterator[DeploymentChatClient]:
    """Yield a ``DeploymentChatClient`` and close it once the test finishes."""
    # ``endpoint`` is required on ClientConfig but unused by AppProxyClient.
    config = ClientConfig(endpoint=URL("http://manager.unused"))
    client = DeploymentChatClient(config)
    try:
        yield client
    finally:
        await client.close()


def _make_body() -> dict[str, Any]:
    """Build a fresh minimal chat-completion request body for each call."""
    return {
        "model": "meta/test-model",
        "messages": [{"role": "user", "content": "hello"}],
    }


def _last_call(mock: aioresponses, method: str, url: str) -> Any:
    """Return the most recent ``RequestCall`` aioresponses captured for (method, url).

    Fails with a descriptive assertion message when nothing was captured
    for the given method/URL pair.
    """
    key = (method.upper(), URL(url))
    # ``.get`` instead of indexing: a missing key must reach the assertion
    # below rather than surface as an uninformative ``KeyError``.
    calls = mock.requests.get(key, [])
    assert calls, f"no request was captured for {method} {url}"
    return calls[-1]


class TestChatCompletionSuccess:
    """Happy-path behavior: URL resolution, headers, and response parsing."""

    async def test_posts_to_v1_chat_completions_with_bearer_header(
        self, chat_client: DeploymentChatClient
    ) -> None:
        with aioresponses() as m:
            m.post(
                CHAT_URL,
                payload={
                    "id": "cmpl-1",
                    "choices": [
                        {
                            "index": 0,
                            "message": {"role": "assistant", "content": "hi"},
                            "finish_reason": "stop",
                        }
                    ],
                },
            )
            resp = await chat_client.chat_completion(BASE_URL, "sk-test-token", _make_body())
            call = _last_call(m, "POST", CHAT_URL)

        assert call.kwargs["headers"]["Authorization"] == "Bearer sk-test-token"
        assert call.kwargs["headers"]["Content-Type"] == "application/json"
        assert call.kwargs["json"] == _make_body()
        assert isinstance(resp, ChatCompletionResponse)
        assert resp.choices[0].message.content == "hi"
        assert resp.assistant_message == "hi"

    async def test_endpoint_url_already_ending_in_chat_completions(
        self, chat_client: DeploymentChatClient
    ) -> None:
        with aioresponses() as m:
            m.post(CHAT_URL, payload={"choices": []})
            await chat_client.chat_completion(CHAT_URL, "sk-x", _make_body())
            assert ("POST", URL(CHAT_URL)) in m.requests

    async def test_endpoint_url_with_trailing_slash_is_normalized(
        self, chat_client: DeploymentChatClient
    ) -> None:
        with aioresponses() as m:
            m.post(CHAT_URL, payload={"choices": []})
            await chat_client.chat_completion(f"{CHAT_URL}/", "sk-x", _make_body())
            assert ("POST", URL(CHAT_URL)) in m.requests

    async def test_omits_authorization_when_token_is_none(
        self, chat_client: DeploymentChatClient
    ) -> None:
        with aioresponses() as m:
            m.post(CHAT_URL, payload={"choices": []})
            await chat_client.chat_completion(BASE_URL, None, _make_body())
            call = _last_call(m, "POST", CHAT_URL)
        assert "Authorization" not in call.kwargs["headers"]


class TestAuthErrors:
    """401 and 403 responses are expected to raise ``DeploymentAuthError``."""

    async def test_401_raises_DeploymentAuthError(self, chat_client: DeploymentChatClient) -> None:
        with aioresponses() as m:
            m.post(CHAT_URL, status=401, payload={"error": "invalid api key"})
            with pytest.raises(DeploymentAuthError) as exc_info:
                await chat_client.chat_completion(BASE_URL, "bad", _make_body())
        assert exc_info.value.status == 401

    async def test_403_raises_DeploymentAuthError(self, chat_client: DeploymentChatClient) -> None:
        with aioresponses() as m:
            m.post(CHAT_URL, status=403, payload={"error": "forbidden"})
            with pytest.raises(DeploymentAuthError) as exc_info:
                await chat_client.chat_completion(BASE_URL, "bad", _make_body())
        # Pin the status like the 401 case so a 403 cannot be silently
        # reported under a different code.
        assert exc_info.value.status == 403


class TestServerErrors:
    """Non-auth HTTP failures are expected to raise plain ``BackendAPIError``."""

    async def test_500_raises_BackendAPIError_not_auth(
        self, chat_client: DeploymentChatClient
    ) -> None:
        with aioresponses() as m:
            m.post(CHAT_URL, status=500, payload={"error": "boom"})
            with pytest.raises(BackendAPIError) as exc_info:
                await chat_client.chat_completion(BASE_URL, "sk", _make_body())
        assert not isinstance(exc_info.value, DeploymentAuthError)
        assert exc_info.value.status == 500


class TestNonJsonResponse:
    """Responses whose body is not JSON, for both 2xx and 5xx statuses."""

    async def test_non_json_2xx_raises_client_error(
        self, chat_client: DeploymentChatClient
    ) -> None:
        with aioresponses() as m:
            m.post(CHAT_URL, status=200, body="not-json", content_type="text/plain")
            with pytest.raises(BackendClientError):
                await chat_client.chat_completion(BASE_URL, "sk", _make_body())

    async def test_html_5xx_raises_backend_api_error_with_body(
        self, chat_client: DeploymentChatClient
    ) -> None:
        # app-proxy / cloud LB error pages: 5xx with HTML body. The HTTP
        # status carries the meaningful failure signal, so this surfaces as
        # BackendAPIError with the raw body in ``detail``.
        with aioresponses() as m:
            m.post(
                CHAT_URL,
                status=502,
                body="Bad Gateway",
                content_type="text/html",
            )
            with pytest.raises(BackendAPIError) as exc_info:
                await chat_client.chat_completion(BASE_URL, "sk", _make_body())
        assert exc_info.value.status == 502
        assert "Bad Gateway" in exc_info.value.data["detail"]


class TestChatCompletionResponseModel:
    """Direct coverage for the response Pydantic model.

    ``DeploymentChatClient.chat_completion`` runs ``model_validate`` on the
    payload, so failures here surface as ``ValidationError`` at the SDK
    boundary instead of corrupting persisted chat history downstream.
    """

    def test_assistant_message_returns_first_choice_text(self) -> None:
        resp = ChatCompletionResponse.model_validate({
            "choices": [
                {"message": {"role": "assistant", "content": "hi 길동"}},
                {"message": {"role": "assistant", "content": "ignored"}},
            ],
        })
        assert resp.assistant_message == "hi 길동"

    def test_assistant_message_none_when_choices_empty(self) -> None:
        # vLLM emits choices=[] on certain error paths; the CLI uses this
        # to skip half-recorded history rounds.
        resp = ChatCompletionResponse.model_validate({"choices": []})
        assert resp.assistant_message is None

    def test_assistant_message_none_for_tool_call_only_response(self) -> None:
        # Function-calling responses leave message.content as null and put
        # the call in tool_calls; nothing text-shaped to persist.
        resp = ChatCompletionResponse.model_validate({
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                "id": "call_1",
                                "type": "function",
                                "function": {"name": "lookup", "arguments": "{}"},
                            },
                        ],
                    },
                },
            ],
        })
        assert resp.assistant_message is None

    def test_extra_top_level_fields_round_trip(self) -> None:
        # Runtime-specific telemetry (usage, system_fingerprint, vLLM
        # prompt_logprobs, NIM extras) must survive parsing so the CLI's
        # JSON pretty-print still shows them to the user.
        payload: dict[str, Any] = {
            "id": "chatcmpl-1",
            "object": "chat.completion",
            "created": 1741569952,
            "model": "vllm/test",
            "choices": [{"message": {"role": "assistant", "content": "hi"}}],
            "usage": {"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5},
            "system_fingerprint": "fp_xyz",
        }
        resp = ChatCompletionResponse.model_validate(payload)
        dumped = resp.model_dump(mode="json")
        assert dumped["usage"]["total_tokens"] == 5
        assert dumped["system_fingerprint"] == "fp_xyz"
        assert dumped["model"] == "vllm/test"

    def test_streaming_chunk_shape_fails_validation(self) -> None:
        # ``delta`` is the streaming-chunk shape; the SDK never sets
        # stream=true, so its arrival means the server misbehaved.
        # Failing loudly is preferable to silently dropping the round.
        with pytest.raises(ValidationError):
            ChatCompletionResponse.model_validate({
                "choices": [
                    {"delta": {"role": "assistant", "content": "partial"}},
                ],
            })