2525from yandex_cloud_ml_sdk ._utils .coerce import coerce_tuple
2626from yandex_cloud_ml_sdk ._utils .sync import run_sync_generator_impl , run_sync_impl
2727
28- from .utils import get_completion_options , get_prompt_trunctation_options
28+ from .prompt_truncation_options import PromptTruncationOptions , PromptTruncationStrategyType
29+ from .utils import get_completion_options
2930
3031if TYPE_CHECKING :
3132 from yandex_cloud_ml_sdk ._sdk import BaseSDK
@@ -36,9 +37,13 @@ class BaseAssistant(ExpirableResource, Generic[RunTypeT, ThreadTypeT]):
3637 expiration_config : ExpirationConfig
3738 model : BaseGPTModel
3839 instruction : str | None
39- max_prompt_tokens : int | None
40+ prompt_truncation_options : PromptTruncationOptions
4041 tools : tuple [BaseTool , ...]
4142
43+ @property
44+ def max_prompt_tokens (self ) -> int | None :
45+ return self .prompt_truncation_options .max_prompt_tokens
46+
4247 @classmethod
4348 def _kwargs_from_message (cls , proto : ProtoAssistant , sdk : BaseSDK ) -> dict [str , Any ]: # type: ignore[override]
4449 kwargs = super ()._kwargs_from_message (proto , sdk = sdk )
@@ -55,9 +60,10 @@ def _kwargs_from_message(cls, proto: ProtoAssistant, sdk: BaseSDK) -> dict[str,
5560 BaseTool ._from_upper_proto (tool , sdk = sdk )
5661 for tool in proto .tools
5762 )
58-
59- if max_prompt_tokens := proto .prompt_truncation_options .max_prompt_tokens .value :
60- kwargs ['max_prompt_tokens' ] = max_prompt_tokens
63+ kwargs ['prompt_truncation_options' ] = PromptTruncationOptions ._from_proto (
64+ proto = proto .prompt_truncation_options ,
65+ sdk = sdk
66+ )
6167
6268 return kwargs
6369
@@ -71,6 +77,7 @@ async def _update(
7177 max_tokens : UndefinedOr [int ] = UNDEFINED ,
7278 instruction : UndefinedOr [str ] = UNDEFINED ,
7379 max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
80+ prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
7481 name : UndefinedOr [str ] = UNDEFINED ,
7582 description : UndefinedOr [str ] = UNDEFINED ,
7683 labels : UndefinedOr [dict [str , str ]] = UNDEFINED ,
@@ -104,16 +111,20 @@ async def _update(
104111 else :
105112 raise TypeError ('model argument must be str, GPTModel object either undefined' )
106113
114+ prompt_truncation_options = PromptTruncationOptions ._coerce (
115+ max_prompt_tokens = max_prompt_tokens ,
116+ strategy = prompt_truncation_strategy
117+ )
118+ proto_prompt_trunction_options = prompt_truncation_options ._to_proto ()
119+
107120 request = UpdateAssistantRequest (
108121 assistant_id = self .id ,
109122 name = get_defined_value (name , '' ),
110123 description = get_defined_value (description , '' ),
111124 labels = get_defined_value (labels , {}),
112125 instruction = get_defined_value (instruction , '' ),
113126 expiration_config = expiration_config .to_proto (),
114- prompt_truncation_options = get_prompt_trunctation_options (
115- max_prompt_tokens = get_defined_value (max_prompt_tokens , None )
116- ),
127+ prompt_truncation_options = proto_prompt_trunction_options ,
117128 completion_options = get_completion_options (
118129 temperature = temperature ,
119130 max_tokens = max_tokens ,
@@ -135,9 +146,8 @@ async def _update(
135146 'model_uri' : model_uri ,
136147 'completion_options.temperature' : temperature ,
137148 'completion_options.max_tokens' : max_tokens ,
138- 'prompt_truncation_options.max_prompt_tokens' : max_prompt_tokens ,
139149 'tools' : tools ,
140- }
150+ } | prompt_truncation_options . _get_update_paths ()
141151 )
142152
143153 async with self ._client .get_service_stub (AssistantServiceStub , timeout = timeout ) as stub :
@@ -215,6 +225,7 @@ async def _run_impl(
215225 custom_temperature : UndefinedOr [float ] = UNDEFINED ,
216226 custom_max_tokens : UndefinedOr [int ] = UNDEFINED ,
217227 custom_max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
228+ custom_prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
218229 timeout : float = 60 ,
219230 ) -> RunTypeT :
220231 return await self ._sdk .runs ._create (
@@ -224,6 +235,7 @@ async def _run_impl(
224235 custom_temperature = custom_temperature ,
225236 custom_max_tokens = custom_max_tokens ,
226237 custom_max_prompt_tokens = custom_max_prompt_tokens ,
238+ custom_prompt_truncation_strategy = custom_prompt_truncation_strategy ,
227239 timeout = timeout ,
228240 )
229241
@@ -234,6 +246,7 @@ async def _run(
234246 custom_temperature : UndefinedOr [float ] = UNDEFINED ,
235247 custom_max_tokens : UndefinedOr [int ] = UNDEFINED ,
236248 custom_max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
249+ custom_prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
237250 timeout : float = 60 ,
238251 ) -> RunTypeT :
239252 return await self ._run_impl (
@@ -242,6 +255,7 @@ async def _run(
242255 custom_temperature = custom_temperature ,
243256 custom_max_tokens = custom_max_tokens ,
244257 custom_max_prompt_tokens = custom_max_prompt_tokens ,
258+ custom_prompt_truncation_strategy = custom_prompt_truncation_strategy ,
245259 timeout = timeout ,
246260 )
247261
@@ -252,6 +266,7 @@ async def _run_stream(
252266 custom_temperature : UndefinedOr [float ] = UNDEFINED ,
253267 custom_max_tokens : UndefinedOr [int ] = UNDEFINED ,
254268 custom_max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
269+ custom_prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
255270 timeout : float = 60 ,
256271 ) -> RunTypeT :
257272 return await self ._run_impl (
@@ -260,6 +275,7 @@ async def _run_stream(
260275 custom_temperature = custom_temperature ,
261276 custom_max_tokens = custom_max_tokens ,
262277 custom_max_prompt_tokens = custom_max_prompt_tokens ,
278+ custom_prompt_truncation_strategy = custom_prompt_truncation_strategy ,
263279 timeout = timeout ,
264280 )
265281
@@ -293,6 +309,7 @@ async def update(
293309 max_tokens : UndefinedOr [int ] = UNDEFINED ,
294310 instruction : UndefinedOr [str ] = UNDEFINED ,
295311 max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
312+ prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
296313 name : UndefinedOr [str ] = UNDEFINED ,
297314 description : UndefinedOr [str ] = UNDEFINED ,
298315 labels : UndefinedOr [dict [str , str ]] = UNDEFINED ,
@@ -307,6 +324,7 @@ async def update(
307324 max_tokens = max_tokens ,
308325 instruction = instruction ,
309326 max_prompt_tokens = max_prompt_tokens ,
327+ prompt_truncation_strategy = prompt_truncation_strategy ,
310328 name = name ,
311329 description = description ,
312330 labels = labels ,
@@ -343,13 +361,15 @@ async def run(
343361 custom_temperature : UndefinedOr [float ] = UNDEFINED ,
344362 custom_max_tokens : UndefinedOr [int ] = UNDEFINED ,
345363 custom_max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
364+ custom_prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
346365 timeout : float = 60 ,
347366 ) -> AsyncRun :
348367 return await self ._run (
349368 thread = thread ,
350369 custom_temperature = custom_temperature ,
351370 custom_max_tokens = custom_max_tokens ,
352371 custom_max_prompt_tokens = custom_max_prompt_tokens ,
372+ custom_prompt_truncation_strategy = custom_prompt_truncation_strategy ,
353373 timeout = timeout
354374 )
355375
@@ -360,13 +380,15 @@ async def run_stream(
360380 custom_temperature : UndefinedOr [float ] = UNDEFINED ,
361381 custom_max_tokens : UndefinedOr [int ] = UNDEFINED ,
362382 custom_max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
383+ custom_prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
363384 timeout : float = 60 ,
364385 ) -> AsyncRun :
365386 return await self ._run_stream (
366387 thread = thread ,
367388 custom_temperature = custom_temperature ,
368389 custom_max_tokens = custom_max_tokens ,
369390 custom_max_prompt_tokens = custom_max_prompt_tokens ,
391+ custom_prompt_truncation_strategy = custom_prompt_truncation_strategy ,
370392 timeout = timeout
371393 )
372394
@@ -380,6 +402,7 @@ def update(
380402 max_tokens : UndefinedOr [int ] = UNDEFINED ,
381403 instruction : UndefinedOr [str ] = UNDEFINED ,
382404 max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
405+ prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
383406 name : UndefinedOr [str ] = UNDEFINED ,
384407 description : UndefinedOr [str ] = UNDEFINED ,
385408 labels : UndefinedOr [dict [str , str ]] = UNDEFINED ,
@@ -394,6 +417,7 @@ def update(
394417 max_tokens = max_tokens ,
395418 instruction = instruction ,
396419 max_prompt_tokens = max_prompt_tokens ,
420+ prompt_truncation_strategy = prompt_truncation_strategy ,
397421 name = name ,
398422 description = description ,
399423 labels = labels ,
@@ -432,13 +456,15 @@ def run(
432456 custom_temperature : UndefinedOr [float ] = UNDEFINED ,
433457 custom_max_tokens : UndefinedOr [int ] = UNDEFINED ,
434458 custom_max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
459+ custom_prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
435460 timeout : float = 60 ,
436461 ) -> Run :
437462 return run_sync_impl (self ._run (
438463 thread = thread ,
439464 custom_temperature = custom_temperature ,
440465 custom_max_tokens = custom_max_tokens ,
441466 custom_max_prompt_tokens = custom_max_prompt_tokens ,
467+ custom_prompt_truncation_strategy = custom_prompt_truncation_strategy ,
442468 timeout = timeout
443469 ), self ._sdk )
444470
@@ -449,13 +475,15 @@ def run_stream(
449475 custom_temperature : UndefinedOr [float ] = UNDEFINED ,
450476 custom_max_tokens : UndefinedOr [int ] = UNDEFINED ,
451477 custom_max_prompt_tokens : UndefinedOr [int ] = UNDEFINED ,
478+ custom_prompt_truncation_strategy : UndefinedOr [PromptTruncationStrategyType ] = UNDEFINED ,
452479 timeout : float = 60 ,
453480 ) -> Run :
454481 return run_sync_impl (self ._run_stream (
455482 thread = thread ,
456483 custom_temperature = custom_temperature ,
457484 custom_max_tokens = custom_max_tokens ,
458485 custom_max_prompt_tokens = custom_max_prompt_tokens ,
486+ custom_prompt_truncation_strategy = custom_prompt_truncation_strategy ,
459487 timeout = timeout
460488 ), self ._sdk )
461489
0 commit comments