Skip to content

Commit 078a69a

Browse files
committed
refactor(BA-5528): replace per-flag sampling args with --params JSON passthrough
Sampling parameters such as temperature, top_p, max_tokens, frequency_penalty, presence_penalty, seed, and stop are runtime-variant-specific (vllm, tgi, sglang, nim accept different sets and even different field names like max_completion_tokens vs max_new_tokens). Exposing them as individual Click options bakes the OpenAI/vllm shape into the CLI, so any runtime that adds or renames a parameter would need a CLI change. Drop all sampling flags and accept a single --params option that takes either an inline JSON object string or '@path' to read JSON from a file. The parsed object is merged into the request body before model and messages are written, so --model and the positional CONTENT always win over anything in --params.
1 parent d7a5bf0 commit 078a69a

1 file changed

Lines changed: 42 additions & 68 deletions

File tree

  • src/ai/backend/client/cli/v2/deployment

src/ai/backend/client/cli/v2/deployment/chat.py

Lines changed: 42 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
from collections.abc import Awaitable, Callable
77
from datetime import UTC, datetime
8+
from pathlib import Path
89
from typing import Any
910
from uuid import UUID
1011

@@ -34,6 +35,26 @@ def _run_async(coro_fn: Callable[[], Awaitable[None]]) -> None:
3435
raise click.ClickException(f"{status}: {detail}") from e
3536

3637

38+
def _parse_params(spec: str) -> dict[str, Any]:
39+
"""Parse the ``--params`` value as a JSON object (or ``@/path/to/file.json``)."""
40+
import json
41+
42+
if spec.startswith("@"):
43+
try:
44+
text = Path(spec[1:]).read_text(encoding="utf-8")
45+
except OSError as e:
46+
raise click.ClickException(f"--params file not readable: {e}") from e
47+
else:
48+
text = spec
49+
try:
50+
parsed = json.loads(text)
51+
except json.JSONDecodeError as e:
52+
raise click.ClickException(f"--params is not valid JSON: {e.msg}") from e
53+
if not isinstance(parsed, dict):
54+
raise click.ClickException("--params must be a JSON object.")
55+
return parsed
56+
57+
3758
@click.command(name="chat")
3859
@click.argument("deployment_id", type=click.UUID)
3960
@click.argument("content", type=str)
@@ -44,60 +65,29 @@ def _run_async(coro_fn: Callable[[], Awaitable[None]]) -> None:
4465
help="Model name to send (defaults to cached default_model).",
4566
)
4667
@click.option(
47-
"--temperature",
48-
default=None,
49-
type=click.FloatRange(min=0.0, max=2.0),
50-
help="Sampling temperature.",
51-
)
52-
@click.option(
53-
"--top-p",
54-
default=None,
55-
type=click.FloatRange(min=0.0, max=1.0),
56-
help="Nucleus sampling probability mass.",
57-
)
58-
@click.option(
59-
"--frequency-penalty",
68+
"--params",
69+
"params_spec",
6070
default=None,
61-
type=click.FloatRange(min=-2.0, max=2.0),
62-
help="Penalty for token frequency.",
63-
)
64-
@click.option(
65-
"--presence-penalty",
66-
default=None,
67-
type=click.FloatRange(min=-2.0, max=2.0),
68-
help="Penalty for token presence.",
69-
)
70-
@click.option(
71-
"--seed",
72-
default=None,
73-
type=int,
74-
help="Random seed for deterministic sampling.",
75-
)
76-
@click.option(
77-
"--stop",
78-
multiple=True,
7971
type=str,
80-
help="Stop sequence (repeatable).",
81-
)
82-
@click.option(
83-
"--max-tokens",
84-
default=None,
85-
type=click.IntRange(min=1),
86-
help="Maximum number of tokens to generate.",
72+
help=(
73+
"Extra request-body fields as a JSON object, or '@PATH' to read from a file. "
74+
"Forwarded to the inference endpoint as-is "
75+
'(e.g. \'{"temperature": 0.7, "max_tokens": 256}\'). '
76+
"The 'model' and 'messages' fields are always overridden by --model and CONTENT."
77+
),
8778
)
8879
def chat(
8980
deployment_id: UUID,
9081
content: str,
9182
model: str | None,
92-
temperature: float | None,
93-
top_p: float | None,
94-
frequency_penalty: float | None,
95-
presence_penalty: float | None,
96-
seed: int | None,
97-
stop: tuple[str, ...],
98-
max_tokens: int | None,
83+
params_spec: str | None,
9984
) -> None:
100-
"""Send a one-shot chat completion request to a deployed model."""
85+
"""Send a one-shot chat completion request to a deployed model.
86+
87+
Sampling parameters (temperature, top_p, max_tokens, etc.) are not
88+
exposed as individual flags because their schema differs across
89+
runtime variants. Pass them through ``--params`` instead.
90+
"""
10191
import json
10292
import sys
10393

@@ -113,6 +103,7 @@ def chat(
113103
except IncompatibleChatCacheError as e:
114104
raise click.ClickException(str(e)) from e
115105

106+
extra_body = _parse_params(params_spec) if params_spec else {}
116107
entry = cache.get(deployment_id)
117108

118109
async def _resolve_endpoint() -> DeploymentChatCacheEntry:
@@ -139,27 +130,6 @@ async def _resolve_endpoint() -> DeploymentChatCacheEntry:
139130
save_chat_cache(cache)
140131
return new_entry
141132

142-
def _build_request_body(model_name: str) -> dict[str, Any]:
143-
body: dict[str, Any] = {
144-
"model": model_name,
145-
"messages": [{"role": "user", "content": content}],
146-
}
147-
if temperature is not None:
148-
body["temperature"] = temperature
149-
if top_p is not None:
150-
body["top_p"] = top_p
151-
if frequency_penalty is not None:
152-
body["frequency_penalty"] = frequency_penalty
153-
if presence_penalty is not None:
154-
body["presence_penalty"] = presence_penalty
155-
if seed is not None:
156-
body["seed"] = seed
157-
if stop:
158-
body["stop"] = list(stop)
159-
if max_tokens is not None:
160-
body["max_tokens"] = max_tokens
161-
return body
162-
163133
async def _run() -> None:
164134
from ai.backend.client.exceptions import BackendAPIError
165135

@@ -173,7 +143,11 @@ async def _run() -> None:
173143
"(this auto-discovers the served model from the inference endpoint)."
174144
)
175145

176-
body = _build_request_body(request_model)
146+
body: dict[str, Any] = {
147+
**extra_body,
148+
"model": request_model,
149+
"messages": [{"role": "user", "content": content}],
150+
}
177151
async with InferenceChatClient(
178152
skip_ssl_verification=config.skip_ssl_verification,
179153
) as client:

0 commit comments

Comments
 (0)