|
1 | | -"""User-facing CLI: ``./bai deployment chat-config`` (manage local chat cache).""" |
| 1 | +"""User-facing CLI: ``./bai deployment chat`` and ``chat-config``.""" |
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
|
11 | 11 | import click |
12 | 12 | from pydantic import BaseModel, ConfigDict, Field, ValidationError |
13 | 13 |
|
14 | | -from ai.backend.client.cli.v2.deployment_chat_cache import ( |
| 14 | +from ai.backend.cli.params import JSONParamType |
| 15 | +from ai.backend.client.cli.v2.deployment.chat.types import ( |
15 | 16 | DeploymentChatCacheEntry, |
16 | 17 | IncompatibleChatCacheError, |
17 | | - load_chat_cache, |
18 | | - save_chat_cache, |
19 | | -) |
20 | | -from ai.backend.client.cli.v2.deployment_chat_config import ( |
21 | 18 | IncompatibleChatConfigError, |
| 19 | +) |
| 20 | +from ai.backend.client.cli.v2.deployment.chat.utils import ( |
| 21 | + load_chat_cache, |
22 | 22 | load_chat_config, |
23 | 23 | mask_token, |
| 24 | + save_chat_cache, |
24 | 25 | save_chat_config, |
25 | 26 | ) |
26 | 27 | from ai.backend.client.cli.v2.helpers import create_v2_registry, load_v2_config |
@@ -50,6 +51,134 @@ def _run_async(coro_fn: Callable[[], Coroutine[Any, Any, None]]) -> None: |
50 | 51 | raise click.ClickException(f"{status}: {detail}") from e |
51 | 52 |
|
52 | 53 |
|
| 54 | +# --------------------------------------------------------------------------- |
| 55 | +# chat |
| 56 | +# --------------------------------------------------------------------------- |
| 57 | + |
| 58 | + |
@click.command(name="chat")
@click.argument("deployment_id", type=click.UUID)
@click.argument("content", type=str)
@click.option(
    "--model",
    default=None,
    type=str,
    help="Model name to send (defaults to cached default_model).",
)
@click.option(
    "--params",
    default="{}",
    type=JSONParamType(),
    help=(
        "Extra request-body fields as a JSON object. "
        "Forwarded to the inference endpoint as-is "
        '(e.g. \'{"temperature": 0.7, "max_tokens": 256}\'). '
        "The 'model' and 'messages' fields are always overridden by --model and CONTENT."
    ),
)
def chat(
    deployment_id: UUID,
    content: str,
    model: str | None,
    params: Any,
) -> None:
    """Send a one-shot chat completion request to a deployed model.

    Sampling parameters (temperature, top_p, max_tokens, etc.) are not
    exposed as individual flags because their schema differs across
    runtime variants. Pass them through ``--params`` instead.
    """
    # NOTE: the docstring above doubles as the command's ``--help`` text, so
    # implementation notes live in comments instead.
    #
    # Flow: load the connection config, the local chat cache and the stored
    # API keys; resolve the deployment's endpoint URL (from the cache, or from
    # the v2 registry when the cache has none); build the request body; send
    # it; pretty-print the JSON response.
    #
    # Every user-facing failure is raised as ``click.ClickException``.

    # Imported at call time rather than at module import.
    import json

    from ai.backend.client.v2.deployment_chat import (
        DeploymentChatAuthError,
        DeploymentChatClient,
    )

    connection = load_v2_config()

    try:
        cache = load_chat_cache()
        chat_config_store = load_chat_config()
    except (IncompatibleChatCacheError, IncompatibleChatConfigError) as e:
        raise click.ClickException(str(e)) from e

    # ``--params`` is parsed as arbitrary JSON; only an object can be merged
    # into the request body.
    if not isinstance(params, dict):
        raise click.ClickException("--params must be a JSON object.")
    extra_body: dict[str, Any] = params
    entry = cache.get(deployment_id)

    async def _ensure_endpoint_entry() -> DeploymentChatCacheEntry:
        """Return a cache entry that carries an endpoint URL.

        Reuses the cached entry when it already has ``endpoint_url``.
        Otherwise fetches the deployment through the v2 registry, stores a
        refreshed entry (keeping any cached ``default_model``) and returns it.
        """
        if entry is not None and entry.endpoint_url:
            return entry
        registry = await create_v2_registry(connection)
        try:
            deployment = await registry.deployment.get(deployment_id)
        finally:
            await registry.close()
        endpoint_url = deployment.network_access.endpoint_url
        if not endpoint_url:
            raise click.ClickException(
                f"Deployment {deployment_id} has no endpoint_url yet "
                "(it may still be provisioning). Wait until the deployment is READY."
            )
        new_entry = DeploymentChatCacheEntry(
            endpoint_url=endpoint_url,
            default_model=entry.default_model if entry is not None else None,
            last_synced_at=datetime.now(UTC),
        )
        cache.upsert(deployment_id, new_entry)
        save_chat_cache(cache)
        return new_entry

    async def _run() -> None:
        """Resolve endpoint and model, send the request, print the response."""
        from ai.backend.client.exceptions import BackendAPIError

        endpoint_entry = await _ensure_endpoint_entry()
        request_model = model or endpoint_entry.default_model
        # ``model or default_model`` can still evaluate to an empty string
        # (e.g. an empty cached default_model). Treat every falsy value as
        # "no model" so we never send ``"model": ""`` to the endpoint.
        if not request_model:
            raise click.ClickException(
                f"No --model given and no default_model cached for deployment {deployment_id}.\n"
                "Set one with:\n"
                f"  ./bai deployment chat-config set {deployment_id} --token <api_key>\n"
                "(this auto-discovers the served model from the inference endpoint)."
            )

        # Extra fields go first so that 'model' and 'messages' always win,
        # as promised by the --params help text.
        body: dict[str, Any] = {
            **extra_body,
            "model": request_model,
            "messages": [{"role": "user", "content": content}],
        }
        api_key = chat_config_store.get_token(deployment_id)
        async with DeploymentChatClient(
            skip_ssl_verification=connection.skip_ssl_verification,
        ) as client:
            try:
                response = await client.chat_completion(
                    endpoint_entry.endpoint_url,
                    api_key,
                    body,
                )
            # The auth error is handled before the generic API error so the
            # user gets the re-registration hint.
            except DeploymentChatAuthError as e:
                raise click.ClickException(
                    f"The inference endpoint rejected the configured API key for "
                    f"deployment {deployment_id}. Re-register with:\n"
                    f"  ./bai deployment chat-config set {deployment_id} --token <api_key>"
                ) from e
            except BackendAPIError as e:
                raise click.ClickException(
                    f"Inference endpoint error ({e.status} {e.reason}): {e.data}"
                ) from e
        print(json.dumps(response, indent=2, ensure_ascii=False, default=str))

    _run_async(_run)
| 175 | + |
| 176 | + |
| 177 | +# --------------------------------------------------------------------------- |
| 178 | +# chat-config |
| 179 | +# --------------------------------------------------------------------------- |
| 180 | + |
| 181 | + |
53 | 182 | @click.group(name="chat-config") |
54 | 183 | def chat_config() -> None: |
55 | 184 | """Manage stored API keys and discovered model names for deployment chat. |
@@ -147,12 +276,7 @@ async def _discover_model( |
147 | 276 | skip_ssl_verification: bool, |
148 | 277 | fallback: str | None, |
149 | 278 | ) -> str | None: |
150 | | - """Call ``GET {endpoint}/v1/models`` to learn the served model name. |
151 | | -
|
152 | | - Returns the first model id reported by the inference endpoint. Falls |
153 | | - back to *fallback* when the endpoint is unreachable or the response |
154 | | - does not contain any model entries. |
155 | | - """ |
| 279 | + """Call ``GET {endpoint}/v1/models`` to learn the served model name.""" |
156 | 280 | from ai.backend.client.exceptions import BackendAPIError, BackendClientError |
157 | 281 | from ai.backend.client.v2.deployment_chat import ( |
158 | 282 | DeploymentChatAuthError, |
@@ -236,4 +360,4 @@ def _print_entry( |
236 | 360 | print(f"api_key : {mask_token(token)}") |
237 | 361 |
|
238 | 362 |
|
239 | | -__all__ = ("chat_config",) |
| 363 | +__all__ = ("chat", "chat_config") |
0 commit comments