Skip to content

Commit 0ca2b24

Browse files
authored
Add tool choice api (#125)
1 parent 3580919 commit 0ca2b24

File tree

9 files changed

+1494
-38
lines changed

9 files changed

+1494
-38
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
7+
from yandex_cloud_ml_sdk import AsyncYCloudML
8+
9+
SCHEMA = {
10+
"type": "object",
11+
"properties": {
12+
"expression": {
13+
"type": "string",
14+
"description": "The mathematical expression to evaluate (e.g., '2 + 3 * 4').",
15+
}
16+
},
17+
"required": ["expression"],
18+
}
19+
20+
21+
async def main() -> None:
22+
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
23+
sdk.setup_default_logging()
24+
25+
calculator_tool = sdk.tools.function(
26+
name="calculator_tool",
27+
description="A simple calculator that performs basic arithmetic and @ operations.",
28+
parameters=SCHEMA # type: ignore[arg-type]
29+
)
30+
another_calculator = sdk.tools.function(
31+
name="another_calculator",
32+
description="A simple calculator that performs basic arithmetic and % operations.",
33+
parameters=SCHEMA # type: ignore[arg-type]
34+
)
35+
36+
model = sdk.models.completions('yandexgpt', model_version='rc').configure(
37+
tools=[calculator_tool, another_calculator],
38+
temperature=0,
39+
# auto is equivalent to default
40+
# tool_choice='auto'
41+
)
42+
43+
request = "How much it would be 7@8?"
44+
result = await model.run(request)
45+
46+
# Model could call the tool, but it depends on many things, for example - model version.
47+
# Right now I writing this example it does not calling the tool
48+
assert result.status.name == 'FINAL'
49+
50+
# You could configure that you don't want to call any tool
51+
model = model.configure(tool_choice='none')
52+
result = await model.run(request)
53+
assert result.status.name == 'FINAL'
54+
55+
# You could configure the model to always call some tool
56+
model = model.configure(tool_choice='required')
57+
result = await model.run(request)
58+
assert result.status.name =='TOOL_CALLS'
59+
assert result.tool_calls
60+
assert len(result.tool_calls) == 1
61+
assert result.tool_calls[0].function
62+
assert result.tool_calls[0].function.name == 'calculator_tool'
63+
64+
# Or configure to call specific tool
65+
model = model.configure(tool_choice={'type': 'function', 'function': {'name': 'another_calculator'}})
66+
# You could pass just a function tool object instead of this big dict
67+
model = model.configure(tool_choice=another_calculator)
68+
result = await model.run(request)
69+
assert result.status.name =='TOOL_CALLS'
70+
assert result.tool_calls
71+
assert len(result.tool_calls) == 1
72+
assert result.tool_calls[0].function
73+
assert result.tool_calls[0].function.name == 'another_calculator'
74+
75+
76+
if __name__ == '__main__':
77+
asyncio.run(main())
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
from yandex_cloud_ml_sdk import YCloudML
6+
7+
SCHEMA = {
8+
"type": "object",
9+
"properties": {
10+
"expression": {
11+
"type": "string",
12+
"description": "The mathematical expression to evaluate (e.g., '2 + 3 * 4').",
13+
}
14+
},
15+
"required": ["expression"],
16+
}
17+
18+
19+
def main() -> None:
20+
sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
21+
sdk.setup_default_logging()
22+
23+
calculator_tool = sdk.tools.function(
24+
name="calculator_tool",
25+
description="A simple calculator that performs basic arithmetic and @ operations.",
26+
parameters=SCHEMA # type: ignore[arg-type]
27+
)
28+
another_calculator = sdk.tools.function(
29+
name="another_calculator",
30+
description="A simple calculator that performs basic arithmetic and % operations.",
31+
parameters=SCHEMA # type: ignore[arg-type]
32+
)
33+
34+
model = sdk.models.completions('yandexgpt', model_version='rc').configure(
35+
tools=[calculator_tool, another_calculator],
36+
temperature=0,
37+
# auto is equivalent to default
38+
# tool_choice='auto'
39+
)
40+
41+
request = "How much it would be 7@8?"
42+
result = model.run(request)
43+
44+
# Model could call the tool, but it depends on many things, for example - model version.
45+
# Right now I writing this example it does not calling the tool
46+
assert result.status.name == 'FINAL'
47+
48+
# You could configure that you don't want to call any tool
49+
model = model.configure(tool_choice='none')
50+
result = model.run(request)
51+
assert result.status.name == 'FINAL'
52+
53+
# You could configure the model to always call some tool
54+
model = model.configure(tool_choice='required')
55+
result = model.run(request)
56+
assert result.status.name =='TOOL_CALLS'
57+
assert result.tool_calls
58+
assert len(result.tool_calls) == 1
59+
assert result.tool_calls[0].function
60+
assert result.tool_calls[0].function.name == 'calculator_tool'
61+
62+
# Or configure to call specific tool
63+
model = model.configure(tool_choice={'type': 'function', 'function': {'name': 'another_calculator'}})
64+
# You could pass just a function tool object instead of this big dict
65+
model = model.configure(tool_choice=another_calculator)
66+
result = model.run(request)
67+
assert result.status.name =='TOOL_CALLS'
68+
assert result.tool_calls
69+
assert len(result.tool_calls) == 1
70+
assert result.tool_calls[0].function
71+
assert result.tool_calls[0].function.name == 'another_calculator'
72+
73+
74+
if __name__ == '__main__':
75+
main()

src/yandex_cloud_ml_sdk/_models/completions/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from yandex_cloud_ml_sdk._tools.tool import FunctionTool
1313
from yandex_cloud_ml_sdk._types.model_config import BaseModelConfig
1414
from yandex_cloud_ml_sdk._types.schemas import ResponseType
15+
from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceType
1516
from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase
1617

1718
_m = ProtoReasoningOptions.ReasoningMode
@@ -35,3 +36,4 @@ class GPTModelConfig(BaseModelConfig):
3536
response_format: ResponseType | None = None
3637
tools: Sequence[CompletionTool] | CompletionTool | None = None
3738
parallel_tool_calls: bool | None = None
39+
tool_choice: ToolChoiceType | None = None

src/yandex_cloud_ml_sdk/_models/completions/model.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing_extensions import Self, override
1010
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import CompletionOptions, ReasoningOptions
1111
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import Tool as ProtoCompletionsTool
12+
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import ToolChoice as ProtoToolChoice
1213
from yandex.cloud.ai.foundation_models.v1.text_generation.text_generation_service_pb2 import (
1314
BatchCompletionMetadata, BatchCompletionRequest, BatchCompletionResponse, CompletionRequest, CompletionResponse,
1415
TokenizeResponse
@@ -29,6 +30,8 @@
2930
)
3031
from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation
3132
from yandex_cloud_ml_sdk._types.schemas import ResponseType, make_response_format_kwargs
33+
from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceType
34+
from yandex_cloud_ml_sdk._types.tool_choice import coerce_to_proto as coerce_to_proto_tool_choice
3235
from yandex_cloud_ml_sdk._types.tuning.datasets import TuningDatasetsType
3336
from yandex_cloud_ml_sdk._types.tuning.optimizers import BaseOptimizer
3437
from yandex_cloud_ml_sdk._types.tuning.schedulers import BaseScheduler
@@ -85,6 +88,7 @@ def configure( # type: ignore[override]
8588
response_format: UndefinedOr[ResponseType] = UNDEFINED,
8689
tools: UndefinedOr[Sequence[CompletionTool] | CompletionTool] = UNDEFINED,
8790
parallel_tool_calls: UndefinedOr[bool] = UNDEFINED,
91+
tool_choice: UndefinedOr[ToolChoiceType] = UNDEFINED,
8892
) -> Self:
8993
return super().configure(
9094
temperature=temperature,
@@ -93,6 +97,7 @@ def configure( # type: ignore[override]
9397
response_format=response_format,
9498
tools=tools,
9599
parallel_tool_calls=parallel_tool_calls,
100+
tool_choice=tool_choice,
96101
)
97102

98103
def _make_completion_options(self, *, stream: bool | None) -> CompletionOptions:
@@ -132,17 +137,22 @@ def _make_request(
132137
if c.parallel_tool_calls is not None:
133138
parallel_tool_calls = BoolValue(value=c.parallel_tool_calls)
134139

140+
tool_choice: None | ProtoToolChoice = None
141+
if c.tool_choice is not None:
142+
tool_choice = coerce_to_proto_tool_choice(c.tool_choice, expected_type=ProtoToolChoice)
143+
135144
return CompletionRequest(
136145
model_uri=self._uri,
137146
completion_options=self._make_completion_options(stream=stream),
138147
messages=messages_to_proto(messages),
139148
tools=[tool._to_proto(ProtoCompletionsTool) for tool in tools],
140149
parallel_tool_calls=parallel_tool_calls,
150+
tool_choice=tool_choice,
141151
**response_format_kwargs,
142152
)
143153

144154
def _make_batch_request(self, dataset_id: str) -> BatchCompletionRequest:
145-
for field in ('tools', 'response_format'):
155+
for field in ('tools', 'response_format', 'tool_choice', 'parallel_tool_calls'):
146156
value = getattr(self.config, field)
147157
if value is not None:
148158
warnings.warn(
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
from typing import Literal, TypedDict, TypeVar, Union, cast
4+
5+
from typing_extensions import TypeAlias
6+
# pylint: disable=no-name-in-module
7+
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import ToolChoice as ProtoCompletionsToolChoice
8+
9+
from yandex_cloud_ml_sdk._tools.tool import FunctionTool
10+
11+
ProtoToolChoice: TypeAlias = ProtoCompletionsToolChoice
12+
ProtoToolChoiceTypeT = TypeVar('ProtoToolChoiceTypeT', bound=ProtoToolChoice)
13+
14+
15+
class FunctionNameType(TypedDict):
16+
name: str
17+
18+
19+
class ToolChoiceDictType(TypedDict):
20+
type: Literal['function']
21+
function: FunctionNameType
22+
23+
24+
ToolChoiceStringType: TypeAlias = Literal[
25+
'none', 'None', 'NONE',
26+
'auto', 'Auto', 'AUTO',
27+
'required', 'Required', 'REQUIRED'
28+
]
29+
30+
ToolChoiceType: TypeAlias = Union[ToolChoiceStringType, ToolChoiceDictType, FunctionTool]
31+
32+
STRING_TOOL_CHOICES = ('NONE', 'AUTO', 'REQUIRED')
33+
34+
35+
def coerce_to_proto(
36+
tool_choice: ToolChoiceType, expected_type: type[ProtoToolChoiceTypeT]
37+
) -> ProtoToolChoiceTypeT:
38+
if isinstance(tool_choice, str):
39+
tool_choice = cast(ToolChoiceStringType, tool_choice.upper())
40+
if tool_choice not in STRING_TOOL_CHOICES:
41+
raise ValueError(f'wrong {tool_choice=}, use one of {STRING_TOOL_CHOICES}')
42+
43+
tool_choice_value = expected_type.ToolChoiceMode.Value(tool_choice)
44+
45+
return expected_type(mode=tool_choice_value)
46+
47+
if isinstance(tool_choice, dict):
48+
if (
49+
tool_choice.get('type') != 'function' or
50+
not isinstance(tool_choice.get('function'), dict) or
51+
not isinstance(tool_choice['function'].get('name'), str)
52+
):
53+
raise ValueError(
54+
'wrong dict structure for tool_choice, expected '
55+
'`{"type": "function", "function": {"name": function_name}}`, '
56+
'got {tool_choice}'
57+
)
58+
59+
tool_choice = cast(ToolChoiceDictType, tool_choice)
60+
61+
return expected_type(function_name=tool_choice['function']['name'])
62+
63+
if isinstance(tool_choice, FunctionTool):
64+
return expected_type(function_name=tool_choice.name)
65+
66+
raise TypeError(f'wrong {type(tool_choice)=}, expected string or dict')

src/yandex_cloud_ml_sdk/_utils/proto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _coerce(cls, value: str | int | ProtoEnumBase) -> Self:
101101
if isinstance(value, str):
102102
if member := cls.__members__.get(value.upper()):
103103
return member
104-
raise ValueError(f'wrong value "{value}" for use as an alisas for {cls}')
104+
raise ValueError(f'wrong value "{value}" for use as an alias for {cls}')
105105
raise TypeError(f'wrong type "{type(value)}" for use as an alias for {cls}')
106106

107107
def _to_proto(self) -> int:

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def patch_operation(request, monkeypatch):
5656

5757
@pytest.fixture(name='folder_id')
5858
def fixture_folder_id():
59-
return 'yc.fomo.storage.prod.service'
59+
return 'b1ghsjum2v37c2un8h64'
6060

6161

6262
@pytest.fixture(name='servicers')

0 commit comments

Comments
 (0)