99from typing_extensions import Self , override
1010from yandex .cloud .ai .foundation_models .v1 .text_common_pb2 import CompletionOptions , ReasoningOptions
1111from yandex .cloud .ai .foundation_models .v1 .text_common_pb2 import Tool as ProtoCompletionsTool
12+ from yandex .cloud .ai .foundation_models .v1 .text_common_pb2 import ToolChoice as ProtoToolChoice
1213from yandex .cloud .ai .foundation_models .v1 .text_generation .text_generation_service_pb2 import (
1314 BatchCompletionMetadata , BatchCompletionRequest , BatchCompletionResponse , CompletionRequest , CompletionResponse ,
1415 TokenizeResponse
2930)
3031from yandex_cloud_ml_sdk ._types .operation import AsyncOperation , Operation
3132from yandex_cloud_ml_sdk ._types .schemas import ResponseType , make_response_format_kwargs
33+ from yandex_cloud_ml_sdk ._types .tool_choice import ToolChoiceType
34+ from yandex_cloud_ml_sdk ._types .tool_choice import coerce_to_proto as coerce_to_proto_tool_choice
3235from yandex_cloud_ml_sdk ._types .tuning .datasets import TuningDatasetsType
3336from yandex_cloud_ml_sdk ._types .tuning .optimizers import BaseOptimizer
3437from yandex_cloud_ml_sdk ._types .tuning .schedulers import BaseScheduler
@@ -85,6 +88,7 @@ def configure( # type: ignore[override]
8588 response_format : UndefinedOr [ResponseType ] = UNDEFINED ,
8689 tools : UndefinedOr [Sequence [CompletionTool ] | CompletionTool ] = UNDEFINED ,
8790 parallel_tool_calls : UndefinedOr [bool ] = UNDEFINED ,
91+ tool_choice : UndefinedOr [ToolChoiceType ] = UNDEFINED ,
8892 ) -> Self :
8993 return super ().configure (
9094 temperature = temperature ,
@@ -93,6 +97,7 @@ def configure( # type: ignore[override]
9397 response_format = response_format ,
9498 tools = tools ,
9599 parallel_tool_calls = parallel_tool_calls ,
100+ tool_choice = tool_choice ,
96101 )
97102
98103 def _make_completion_options (self , * , stream : bool | None ) -> CompletionOptions :
@@ -132,17 +137,22 @@ def _make_request(
132137 if c .parallel_tool_calls is not None :
133138 parallel_tool_calls = BoolValue (value = c .parallel_tool_calls )
134139
140+ tool_choice : None | ProtoToolChoice = None
141+ if c .tool_choice is not None :
142+ tool_choice = coerce_to_proto_tool_choice (c .tool_choice , expected_type = ProtoToolChoice )
143+
135144 return CompletionRequest (
136145 model_uri = self ._uri ,
137146 completion_options = self ._make_completion_options (stream = stream ),
138147 messages = messages_to_proto (messages ),
139148 tools = [tool ._to_proto (ProtoCompletionsTool ) for tool in tools ],
140149 parallel_tool_calls = parallel_tool_calls ,
150+ tool_choice = tool_choice ,
141151 ** response_format_kwargs ,
142152 )
143153
144154 def _make_batch_request (self , dataset_id : str ) -> BatchCompletionRequest :
145- for field in ('tools' , 'response_format' ):
155+ for field in ('tools' , 'response_format' , 'tool_choice' , 'parallel_tool_calls' ):
146156 value = getattr (self .config , field )
147157 if value is not None :
148158 warnings .warn (
0 commit comments