22from __future__ import annotations
33
44import dataclasses
5+ from collections .abc import Iterable
56from datetime import datetime
67from typing import TYPE_CHECKING , Any , AsyncIterator , Generic , Iterator , TypeVar
78
1213 UpdateAssistantRequest
1314)
1415from yandex .cloud .ai .assistants .v1 .assistant_service_pb2_grpc import AssistantServiceStub
16+ from yandex .cloud .ai .assistants .v1 .common_pb2 import Tool as ProtoAssistantsTool
1517
1618from yandex_cloud_ml_sdk ._models .completions .model import BaseGPTModel
1719from yandex_cloud_ml_sdk ._runs .run import AsyncRun , Run , RunTypeT
2022from yandex_cloud_ml_sdk ._types .expiration import ExpirationConfig , ExpirationPolicyAlias
2123from yandex_cloud_ml_sdk ._types .misc import UNDEFINED , UndefinedOr , get_defined_value , is_defined
2224from yandex_cloud_ml_sdk ._types .resource import ExpirableResource , safe_on_delete
25+ from yandex_cloud_ml_sdk ._utils .coerce import coerce_tuple
2326from yandex_cloud_ml_sdk ._utils .sync import run_sync_generator_impl , run_sync_impl
2427
2528from .utils import get_completion_options , get_prompt_trunctation_options
@@ -72,6 +75,7 @@ async def _update(
7275 description : UndefinedOr [str ] = UNDEFINED ,
7376 labels : UndefinedOr [dict [str , str ]] = UNDEFINED ,
7477 ttl_days : UndefinedOr [int ] = UNDEFINED ,
78+ tools : UndefinedOr [Iterable [BaseTool ]] = UNDEFINED ,
7579 expiration_policy : UndefinedOr [ExpirationPolicyAlias ] = UNDEFINED ,
7680 timeout : float = 60 ,
7781 ) -> Self :
@@ -83,6 +87,11 @@ async def _update(
8387
8488 model_uri : UndefinedOr [str ] | None = UNDEFINED
8589
90+ tools_ : tuple [BaseTool , ...] = ()
91+ if is_defined (tools ):
92+ # NB: mypy doesn't love abstract class used as TypeVar substitution here
93+ tools_ = coerce_tuple (tools , BaseTool ) # type: ignore[type-abstract]
94+
8695 if is_defined (model ):
8796 if isinstance (model , str ):
8897 model_uri = self ._sdk .models .completions (model ).uri
@@ -108,7 +117,8 @@ async def _update(
108117 completion_options = get_completion_options (
109118 temperature = temperature ,
110119 max_tokens = max_tokens ,
111- )
120+ ),
121+ tools = [tool ._to_proto (ProtoAssistantsTool ) for tool in tools_ ]
112122 )
113123 if model_uri and is_defined (model_uri ):
114124 request .model_uri = model_uri
@@ -126,6 +136,7 @@ async def _update(
126136 'completion_options.temperature' : temperature ,
127137 'completion_options.max_tokens' : max_tokens ,
128138 'prompt_truncation_options.max_prompt_tokens' : max_prompt_tokens ,
139+ 'tools' : tools ,
129140 }
130141 )
131142
@@ -286,6 +297,7 @@ async def update(
286297 description : UndefinedOr [str ] = UNDEFINED ,
287298 labels : UndefinedOr [dict [str , str ]] = UNDEFINED ,
288299 ttl_days : UndefinedOr [int ] = UNDEFINED ,
300+ tools : UndefinedOr [Iterable [BaseTool ]] = UNDEFINED ,
289301 expiration_policy : UndefinedOr [ExpirationPolicyAlias ] = UNDEFINED ,
290302 timeout : float = 60 ,
291303 ) -> Self :
@@ -299,6 +311,7 @@ async def update(
299311 description = description ,
300312 labels = labels ,
301313 ttl_days = ttl_days ,
314+ tools = tools ,
302315 expiration_policy = expiration_policy ,
303316 timeout = timeout
304317 )
@@ -371,6 +384,7 @@ def update(
371384 description : UndefinedOr [str ] = UNDEFINED ,
372385 labels : UndefinedOr [dict [str , str ]] = UNDEFINED ,
373386 ttl_days : UndefinedOr [int ] = UNDEFINED ,
387+ tools : UndefinedOr [Iterable [BaseTool ]] = UNDEFINED ,
374388 expiration_policy : UndefinedOr [ExpirationPolicyAlias ] = UNDEFINED ,
375389 timeout : float = 60 ,
376390 ) -> Self :
@@ -384,6 +398,7 @@ def update(
384398 description = description ,
385399 labels = labels ,
386400 ttl_days = ttl_days ,
401+ tools = tools ,
387402 expiration_policy = expiration_policy ,
388403 timeout = timeout
389404 ), self ._sdk )
0 commit comments