Skip to content

Commit c1e54ee

Browse files
authored
Add extra_query parameter (#160)
1 parent a1a1571 commit c1e54ee

File tree

8 files changed

+263
-2
lines changed

8 files changed

+263
-2
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
('py:class', 'yandex_cloud_ml_sdk._types.schemas.JsonArray'),
7777
('py:class', "'yandex_cloud_ml_sdk._types.schemas.JsonArray'"),
7878
('py:class', 'JsonObject'),
79+
('py:class', 'JsonArray'),
7980
('py:class', 'JsonSchemaType'),
8081
('py:class', 'ResponseType'),
8182
}

docs/types/other.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,21 @@ Miscellaneous types
8484
:no-inherited-members:
8585

8686
.. py:class:: yandex_cloud_ml_sdk._tools.tool_call_list.HttpToolCallList
87+
88+
89+
Completions-related types
90+
~~~~~~~~~~~~~~~~~~~~~~~~~
91+
92+
.. py:class:: yandex_cloud_ml_sdk._chat.completions.config.ChatReasoningModeType
93+
94+
.. autodata:: yandex_cloud_ml_sdk._chat.completions.config.ChatReasoningModeType
95+
96+
.. py:class:: yandex_cloud_ml_sdk._chat.completions.config.QueryType
97+
98+
.. py:class:: yandex_cloud_ml_sdk._models.completions.config.CompletionTool
99+
100+
.. autodata:: yandex_cloud_ml_sdk._models.completions.config.CompletionTool
101+
102+
.. py:class:: yandex_cloud_ml_sdk._types.tools.tool_choice.ToolChoiceType
103+
104+
.. autodata:: yandex_cloud_ml_sdk._types.tools.tool_choice.ToolChoiceType

examples/async/chat/extra_query.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
7+
from yandex_cloud_ml_sdk import AsyncYCloudML
8+
9+
10+
async def get_model(sdk: AsyncYCloudML):
    """Interactively pick one of the chat models available to the user.

    Prints every model returned by ``sdk.chat.completions.list()`` together
    with its index, reads a number from stdin and returns the model stored
    at that index.

    :param sdk: SDK instance used to fetch the list of chat models.
    :returns: the model found at the entered index.
    :raises RuntimeError: if the list of models is empty, so there is
        nothing to choose from.
    :raises ValueError: if the input is not an integer or lies outside of
        the printed index range.
    """
    models = await sdk.chat.completions.list()

    print('You have access to the following models:')
    # Stays at -1 when the loop body never runs, which lets us tell an empty
    # list apart from a list holding a single model (index 0).
    last_index = -1
    for last_index, model in enumerate(models):
        print(f" [{last_index:2}] {model.uri}")

    if last_index < 0:
        raise RuntimeError('There are no chat models available to choose from')

    raw_number = input(f"Please, input model number from 0 to {last_index}: ")
    number = int(raw_number)
    # Reject negative numbers explicitly: plain indexing would silently
    # count them from the end of the list.
    if not 0 <= number <= last_index:
        raise ValueError(
            f"Model number must be between 0 and {last_index}, got {number}"
        )
    return models[number]
20+
21+
22+
async def main() -> None:
    """Show how arbitrary request parameters are passed via ``extra_query``."""
    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    model = await get_model(sdk)

    # Any additional request parameters can be handed to the model
    # through the extra_query configuration parameter
    model = model.configure(temperature=0.5, extra_query={'top_p': 0.2})

    # Note that reconfiguring extra_query replaces its value entirely,
    # the previous and the new values are not merged
    model = model.configure(extra_query={'top_k': 2})
    print(f"{model.config.extra_query=} {model.config.temperature=}")

    prompt = 'Say random number from 0 to 10'
    # Each scenario is a (title, extra_query) pair; every pair is listed
    # twice so the results of identical settings can be compared
    scenarios = (
        ('deterministic', {'top_k': 2, 'top_p': 0.1}),
        ('another deterministic', {'top_k': 2, 'top_p': 0.1}),
        ('more random', {'top_k': 5, 'top_p': 1}),
        ('another more random', {'top_k': 5, 'top_p': 1}),
    )
    for label, params in scenarios:
        model = model.configure(extra_query=params)
        answer = await model.run(prompt)
        print(f"{label} result: {answer.text}")

    # Also note that the extra query value is not validated on the client side at all:
    model = model.configure(extra_query={'foo': 2})
    # This will not fail:
    await model.run(prompt)
    # So, refer to the model's documentation to find out which extra parameters it supports
53+
54+
55+
if __name__ == '__main__':
56+
asyncio.run(main())

examples/sync/chat/extra_query.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
from yandex_cloud_ml_sdk import YCloudML
6+
7+
8+
def get_model(sdk: YCloudML):
    """Interactively pick one of the chat models available to the user.

    Prints every model returned by ``sdk.chat.completions.list()`` together
    with its index, reads a number from stdin and returns the model stored
    at that index.

    :param sdk: SDK instance used to fetch the list of chat models.
    :returns: the model found at the entered index.
    :raises RuntimeError: if the list of models is empty, so there is
        nothing to choose from.
    :raises ValueError: if the input is not an integer or lies outside of
        the printed index range.
    """
    models = sdk.chat.completions.list()

    print('You have access to the following models:')
    # Stays at -1 when the loop body never runs, which lets us tell an empty
    # list apart from a list holding a single model (index 0).
    last_index = -1
    for last_index, model in enumerate(models):
        print(f" [{last_index:2}] {model.uri}")

    if last_index < 0:
        raise RuntimeError('There are no chat models available to choose from')

    raw_number = input(f"Please, input model number from 0 to {last_index}: ")
    number = int(raw_number)
    # Reject negative numbers explicitly: plain indexing would silently
    # count them from the end of the list.
    if not 0 <= number <= last_index:
        raise ValueError(
            f"Model number must be between 0 and {last_index}, got {number}"
        )
    return models[number]
18+
19+
20+
def main() -> None:
    """Show how arbitrary request parameters are passed via ``extra_query``."""
    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    model = get_model(sdk)

    # Any additional request parameters can be handed to the model
    # through the extra_query configuration parameter
    model = model.configure(temperature=0.5, extra_query={'top_p': 0.2})

    # Note that reconfiguring extra_query replaces its value entirely,
    # the previous and the new values are not merged
    model = model.configure(extra_query={'top_k': 2})
    print(f"{model.config.extra_query=} {model.config.temperature=}")

    prompt = 'Say random number from 0 to 10'
    # Each scenario is a (title, extra_query) pair; every pair is listed
    # twice so the results of identical settings can be compared
    scenarios = (
        ('deterministic', {'top_k': 2, 'top_p': 0.1}),
        ('another deterministic', {'top_k': 2, 'top_p': 0.1}),
        ('more random', {'top_k': 5, 'top_p': 1}),
        ('another more random', {'top_k': 5, 'top_p': 1}),
    )
    for label, params in scenarios:
        model = model.configure(extra_query=params)
        answer = model.run(prompt)
        print(f"{label} result: {answer.text}")

    # Also note that the extra query value is not validated on the client side at all:
    model = model.configure(extra_query={'foo': 2})
    # This will not fail:
    model.run(prompt)
    # So, refer to the model's documentation to find out which extra parameters it supports
51+
52+
53+
if __name__ == '__main__':
54+
main()

src/yandex_cloud_ml_sdk/_chat/completions/config.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

3+
from copy import deepcopy
34
from dataclasses import dataclass
45
from enum import Enum
56
from typing import Any, Union
67

7-
from typing_extensions import Self
8+
from typing_extensions import Self, TypeAlias
89

910
from yandex_cloud_ml_sdk._models.completions.config import CompletionTool, GPTModelConfig
1011
from yandex_cloud_ml_sdk._tools.tool import BaseTool
12+
from yandex_cloud_ml_sdk._types.schemas import JsonObject
1113
from yandex_cloud_ml_sdk._utils.coerce import coerce_tuple
1214

1315

@@ -26,12 +28,14 @@ def _coerce(cls, value: ChatReasoningModeType) -> Self:
2628

2729

2830
ChatReasoningModeType = Union[str, ChatReasoningMode]
31+
QueryType: TypeAlias = JsonObject
2932

3033

3134
@dataclass(frozen=True)
3235
class ChatModelConfig(GPTModelConfig):
3336
reasoning_mode: ChatReasoningMode | None = None
3437
tools: tuple[CompletionTool, ...] | None = None
38+
extra_query: QueryType | None = None
3539

3640
def _replace(self, **kwargs: Any) -> Self:
3741
if reasoning_mode := kwargs.get('reasoning_mode'):
@@ -40,4 +44,9 @@ def _replace(self, **kwargs: Any) -> Self:
4044
if tools := kwargs.get('tools'):
4145
kwargs['tools'] = coerce_tuple(tools, BaseTool) # type: ignore[type-abstract]
4246

47+
extra_query: QueryType | None
48+
if extra_query := kwargs.get('extra_query'):
49+
assert isinstance(extra_query, dict)
50+
kwargs['extra_query'] = deepcopy(extra_query)
51+
4352
return super()._replace(**kwargs)

src/yandex_cloud_ml_sdk/_chat/completions/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from yandex_cloud_ml_sdk._types.tools.tool_choice import coerce_to_json as coerce_tool_choice_to_json
1616
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator
1717

18-
from .config import ChatModelConfig, ChatReasoningModeType
18+
from .config import ChatModelConfig, ChatReasoningModeType, QueryType
1919
from .message import ChatMessageInputType, messages_to_json
2020
from .result import ChatModelResult
2121

@@ -40,6 +40,7 @@ def configure( # type: ignore[override]
4040
tools: UndefinedOr[Sequence[CompletionTool] | CompletionTool] = UNDEFINED,
4141
parallel_tool_calls: UndefinedOr[bool] = UNDEFINED,
4242
tool_choice: UndefinedOr[ToolChoiceType] = UNDEFINED,
43+
extra_query: UndefinedOr[QueryType] = UNDEFINED,
4344
) -> Self:
4445
return super().configure(
4546
temperature=temperature,
@@ -49,6 +50,7 @@ def configure( # type: ignore[override]
4950
tools=tools,
5051
parallel_tool_calls=parallel_tool_calls,
5152
tool_choice=tool_choice,
53+
extra_query=extra_query,
5254
)
5355

5456
def _build_request_json(self, messages: ChatMessageInputType, stream: bool) -> dict[str, Any]:
@@ -85,6 +87,10 @@ def _build_request_json(self, messages: ChatMessageInputType, stream: bool) -> d
8587

8688
if c.tool_choice is not None:
8789
result['tool_choice'] = coerce_tool_choice_to_json(c.tool_choice)
90+
91+
if c.extra_query is not None:
92+
result.update(c.extra_query)
93+
8894
return result
8995

9096
@override
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
interactions:
2+
- request:
3+
body: '{"model":"gpt://b1ghsjum2v37c2un8h64/yandexgpt/latest","messages":[{"role":"user","content":"Say
4+
random number from 0 to 10"}],"stream":false}'
5+
headers:
6+
accept:
7+
- '*/*'
8+
accept-encoding:
9+
- gzip, deflate, zstd
10+
connection:
11+
- keep-alive
12+
content-length:
13+
- '142'
14+
content-type:
15+
- application/json
16+
host:
17+
- llm.api.cloud.yandex.net
18+
user-agent:
19+
- yandex-cloud-ml-sdk/0.15.0 python/3.12
20+
x-client-request-id:
21+
- e39a16aa-9ecc-4769-a220-d99a47e9997c
22+
method: POST
23+
uri: https://llm.api.cloud.yandex.net/v1/chat/completions
24+
response:
25+
body:
26+
string: '{"id":"09afef25-7de1-4a86-8cda-1cb673f2ddc1","object":"chat.completion","created":1758116486,"model":"gpt://b1ghsjum2v37c2un8h64/yandexgpt/latest","choices":[{"index":0,"message":{"role":"assistant","content":"7"},"finish_reason":"stop"}],"usage":{"prompt_tokens":20,"total_tokens":22,"completion_tokens":2}}
27+
28+
'
29+
headers:
30+
content-length:
31+
- '309'
32+
content-type:
33+
- application/json
34+
date:
35+
- Wed, 17 Sep 2025 13:41:26 GMT
36+
server:
37+
- ycalb
38+
x-client-request-id:
39+
- e39a16aa-9ecc-4769-a220-d99a47e9997c
40+
x-request-id:
41+
- c3402ee1-93c6-4c99-a660-f978ec004b35
42+
x-server-trace-id:
43+
- 990d344b241074c2:295875c4dbdaafe7:990d344b241074c2:1
44+
status:
45+
code: 200
46+
message: OK
47+
- request:
48+
body: '{"model":"gpt://b1ghsjum2v37c2un8h64/yandexgpt/latest","messages":[{"role":"user","content":"Say
49+
random number from 0 to 10"}],"stream":false,"top_k":3}'
50+
headers:
51+
accept:
52+
- '*/*'
53+
accept-encoding:
54+
- gzip, deflate, zstd
55+
connection:
56+
- keep-alive
57+
content-length:
58+
- '152'
59+
content-type:
60+
- application/json
61+
host:
62+
- llm.api.cloud.yandex.net
63+
user-agent:
64+
- yandex-cloud-ml-sdk/0.15.0 python/3.12
65+
x-client-request-id:
66+
- 5a710077-50b4-43ed-b269-78116ab15066
67+
method: POST
68+
uri: https://llm.api.cloud.yandex.net/v1/chat/completions
69+
response:
70+
body:
71+
string: '{"id":"1ee562fa-00fa-4d56-84a5-4bcea60cd694","object":"chat.completion","created":1758116487,"model":"gpt://b1ghsjum2v37c2un8h64/yandexgpt/latest","choices":[{"index":0,"message":{"role":"assistant","content":"7"},"finish_reason":"stop"}],"usage":{"prompt_tokens":20,"total_tokens":22,"completion_tokens":2}}
72+
73+
'
74+
headers:
75+
content-length:
76+
- '309'
77+
content-type:
78+
- application/json
79+
date:
80+
- Wed, 17 Sep 2025 13:41:27 GMT
81+
server:
82+
- ycalb
83+
x-client-request-id:
84+
- 5a710077-50b4-43ed-b269-78116ab15066
85+
x-request-id:
86+
- e3cb30bc-4cdb-4110-a4ba-90c9b2744aa9
87+
x-server-trace-id:
88+
- 16405981e4f1ec7d:64b3b33e662bcec:16405981e4f1ec7d:1
89+
status:
90+
code: 200
91+
message: OK
92+
version: 1

tests/chat/test_completions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pathlib
66
from typing import cast
77

8+
import httpx._client
89
import pytest
910

1011
from yandex_cloud_ml_sdk import AsyncYCloudML
@@ -388,3 +389,27 @@ async def test_multimodal(async_sdk: AsyncYCloudML) -> None:
388389
]
389390
result = await model.run(request)
390391
assert 'complex' in result.text
392+
393+
394+
async def test_extra_query(async_sdk: AsyncYCloudML, monkeypatch) -> None:
    """Check that ``extra_query`` entries reach the JSON body of the HTTP request."""
    captured_top_k = None
    real_request = httpx._client.AsyncClient.request

    async def spy_request(*args, **kwargs):
        # Remember the top_k value of the outgoing payload, then delegate
        # to the unpatched implementation so the request is still performed.
        nonlocal captured_top_k
        payload = kwargs.get('json', {})
        captured_top_k = payload.get('top_k')
        return await real_request(*args, **kwargs)

    monkeypatch.setattr("httpx._client.AsyncClient.request", spy_request)

    prompt = "Say random number from 0 to 10"
    model = async_sdk.chat.completions('yandexgpt')

    # Without extra_query the request must not carry a top_k value
    await model.run(prompt)
    assert not captured_top_k

    # Once configured, the extra_query entry must show up in the payload
    configured = model.configure(extra_query={'top_k': 3})
    await configured.run(prompt)
    assert captured_top_k == 3

0 commit comments

Comments
 (0)