55import asyncio
66from collections .abc import Awaitable , Callable
77from datetime import UTC , datetime
8+ from pathlib import Path
89from typing import Any
910from uuid import UUID
1011
@@ -34,6 +35,26 @@ def _run_async(coro_fn: Callable[[], Awaitable[None]]) -> None:
3435 raise click .ClickException (f"{ status } : { detail } " ) from e
3536
3637
38+ def _parse_params (spec : str ) -> dict [str , Any ]:
39+ """Parse the ``--params`` value as a JSON object (or ``@/path/to/file.json``)."""
40+ import json
41+
42+ if spec .startswith ("@" ):
43+ try :
44+ text = Path (spec [1 :]).read_text (encoding = "utf-8" )
45+ except OSError as e :
46+ raise click .ClickException (f"--params file not readable: { e } " ) from e
47+ else :
48+ text = spec
49+ try :
50+ parsed = json .loads (text )
51+ except json .JSONDecodeError as e :
52+ raise click .ClickException (f"--params is not valid JSON: { e .msg } " ) from e
53+ if not isinstance (parsed , dict ):
54+ raise click .ClickException ("--params must be a JSON object." )
55+ return parsed
56+
57+
3758@click .command (name = "chat" )
3859@click .argument ("deployment_id" , type = click .UUID )
3960@click .argument ("content" , type = str )
@@ -44,60 +65,29 @@ def _run_async(coro_fn: Callable[[], Awaitable[None]]) -> None:
4465 help = "Model name to send (defaults to cached default_model)." ,
4566)
4667@click .option (
47- "--temperature" ,
48- default = None ,
49- type = click .FloatRange (min = 0.0 , max = 2.0 ),
50- help = "Sampling temperature." ,
51- )
52- @click .option (
53- "--top-p" ,
54- default = None ,
55- type = click .FloatRange (min = 0.0 , max = 1.0 ),
56- help = "Nucleus sampling probability mass." ,
57- )
58- @click .option (
59- "--frequency-penalty" ,
68+ "--params" ,
69+ "params_spec" ,
6070 default = None ,
61- type = click .FloatRange (min = - 2.0 , max = 2.0 ),
62- help = "Penalty for token frequency." ,
63- )
64- @click .option (
65- "--presence-penalty" ,
66- default = None ,
67- type = click .FloatRange (min = - 2.0 , max = 2.0 ),
68- help = "Penalty for token presence." ,
69- )
70- @click .option (
71- "--seed" ,
72- default = None ,
73- type = int ,
74- help = "Random seed for deterministic sampling." ,
75- )
76- @click .option (
77- "--stop" ,
78- multiple = True ,
7971 type = str ,
80- help = "Stop sequence (repeatable)." ,
81- )
82- @click .option (
83- "--max-tokens" ,
84- default = None ,
85- type = click .IntRange (min = 1 ),
86- help = "Maximum number of tokens to generate." ,
72+ help = (
73+ "Extra request-body fields as a JSON object, or '@PATH' to read from a file. "
74+ "Forwarded to the inference endpoint as-is "
75+ '(e.g. \' {"temperature": 0.7, "max_tokens": 256}\' ). '
76+ "The 'model' and 'messages' fields are always overridden by --model and CONTENT."
77+ ),
8778)
8879def chat (
8980 deployment_id : UUID ,
9081 content : str ,
9182 model : str | None ,
92- temperature : float | None ,
93- top_p : float | None ,
94- frequency_penalty : float | None ,
95- presence_penalty : float | None ,
96- seed : int | None ,
97- stop : tuple [str , ...],
98- max_tokens : int | None ,
83+ params_spec : str | None ,
9984) -> None :
100- """Send a one-shot chat completion request to a deployed model."""
85+ """Send a one-shot chat completion request to a deployed model.
86+
87+ Sampling parameters (temperature, top_p, max_tokens, etc.) are not
88+ exposed as individual flags because their schema differs across
89+ runtime variants. Pass them through ``--params`` instead.
90+ """
10191 import json
10292 import sys
10393
@@ -113,6 +103,7 @@ def chat(
113103 except IncompatibleChatCacheError as e :
114104 raise click .ClickException (str (e )) from e
115105
106+ extra_body = _parse_params (params_spec ) if params_spec else {}
116107 entry = cache .get (deployment_id )
117108
118109 async def _resolve_endpoint () -> DeploymentChatCacheEntry :
@@ -139,27 +130,6 @@ async def _resolve_endpoint() -> DeploymentChatCacheEntry:
139130 save_chat_cache (cache )
140131 return new_entry
141132
142- def _build_request_body (model_name : str ) -> dict [str , Any ]:
143- body : dict [str , Any ] = {
144- "model" : model_name ,
145- "messages" : [{"role" : "user" , "content" : content }],
146- }
147- if temperature is not None :
148- body ["temperature" ] = temperature
149- if top_p is not None :
150- body ["top_p" ] = top_p
151- if frequency_penalty is not None :
152- body ["frequency_penalty" ] = frequency_penalty
153- if presence_penalty is not None :
154- body ["presence_penalty" ] = presence_penalty
155- if seed is not None :
156- body ["seed" ] = seed
157- if stop :
158- body ["stop" ] = list (stop )
159- if max_tokens is not None :
160- body ["max_tokens" ] = max_tokens
161- return body
162-
163133 async def _run () -> None :
164134 from ai .backend .client .exceptions import BackendAPIError
165135
@@ -173,7 +143,11 @@ async def _run() -> None:
173143 "(this auto-discovers the served model from the inference endpoint)."
174144 )
175145
176- body = _build_request_body (request_model )
146+ body : dict [str , Any ] = {
147+ ** extra_body ,
148+ "model" : request_model ,
149+ "messages" : [{"role" : "user" , "content" : content }],
150+ }
177151 async with InferenceChatClient (
178152 skip_ssl_verification = config .skip_ssl_verification ,
179153 ) as client :
0 commit comments