33
44import dataclasses
55from datetime import datetime
6- from typing import TYPE_CHECKING , Any , AsyncIterator , Iterator , TypeVar , cast
6+ from typing import TYPE_CHECKING , Any , AsyncIterator , ClassVar , Iterator , TypeVar , cast
77
88from google .protobuf .wrappers_pb2 import Int64Value
99from yandex .cloud .ai .assistants .v1 .runs .run_pb2 import Run as ProtoRun
1616from yandex_cloud_ml_sdk ._tools .tool_result import (
1717 ProtoAssistantToolResultList , ToolResultInputType , tool_results_to_proto
1818)
19- from yandex_cloud_ml_sdk ._types .operation import OperationInterface
19+ from yandex_cloud_ml_sdk ._types .operation import AsyncOperationMixin , OperationInterface , SyncOperationMixin
2020from yandex_cloud_ml_sdk ._types .resource import BaseResource
2121from yandex_cloud_ml_sdk ._types .result import ProtoMessage
2222from yandex_cloud_ml_sdk ._types .schemas import ResponseType
3131
3232
3333@dataclasses .dataclass (frozen = True )
34- class BaseRun (BaseResource , OperationInterface [RunResult [ToolCallTypeT ]]):
34+ class BaseRun (BaseResource , OperationInterface [RunResult [ToolCallTypeT ], RunStatus ]):
3535 id : str
3636 assistant_id : str
3737 thread_id : str
@@ -43,6 +43,9 @@ class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT]]):
4343 custom_prompt_truncation_options : PromptTruncationOptions | None
4444 custom_response_format : ResponseType | None
4545
46+ _default_poll_timeout : ClassVar [int ] = 300
47+ _default_poll_interval : ClassVar [float ] = 0.5
48+
4649 @property
4750 def custom_max_prompt_tokens (self ) -> int | None :
4851 if self .custom_prompt_truncation_options :
@@ -167,14 +170,15 @@ async def requests() -> AsyncIterator[AttachRunRequest]:
167170
168171 return
169172
173+ async def _cancel (
174+ self ,
175+ * ,
176+ timeout : float = 60
177+ ) -> None :
178+ raise NotImplementedError ("Run couldn't be cancelled" )
170179
171- class AsyncRun (BaseRun [AsyncToolCall ]):
172- async def get_status (self , * , timeout : float = 60 ) -> RunStatus :
173- return await self ._get_status (timeout = timeout )
174-
175- async def get_result (self , * , timeout : float = 60 ) -> RunResult [AsyncToolCall ]:
176- return await self ._get_result (timeout = timeout )
177180
181+ class AsyncRun (AsyncOperationMixin [RunResult [AsyncToolCall ], RunStatus ], BaseRun [AsyncToolCall ]):
178182 async def listen (
179183 self ,
180184 * ,
@@ -189,22 +193,6 @@ async def listen(
189193
190194 __aiter__ = listen
191195
192- async def wait (
193- self ,
194- * ,
195- timeout : float = 60 ,
196- poll_timeout : int = 300 ,
197- poll_interval : float = 0.5 ,
198- ) -> RunResult [AsyncToolCall ]:
199- return await self ._wait (
200- timeout = timeout ,
201- poll_timeout = poll_timeout ,
202- poll_interval = poll_interval ,
203- )
204-
205- def __await__ (self ):
206- return self .wait ().__await__ ()
207-
208196 async def submit_tool_results (
209197 self ,
210198 tool_results : ToolResultInputType ,
@@ -214,20 +202,11 @@ async def submit_tool_results(
214202 await super ()._submit_tool_results (tool_results = tool_results , timeout = timeout )
215203
216204
217- class Run (BaseRun [ToolCall ]):
218- __get_status = run_sync (BaseRun ._get_status )
219- __get_result = run_sync (BaseRun ._get_result )
220- __wait = run_sync (BaseRun ._wait )
205+ class Run (SyncOperationMixin [RunResult [ToolCall ], RunStatus ], BaseRun [ToolCall ]):
221206 __listen = run_sync_generator (BaseRun ._listen )
222207 __iter__ = __listen
223208 __submit_tool_results = run_sync (BaseRun ._submit_tool_results )
224209
225- def get_status (self , * , timeout : float = 60 ) -> RunStatus :
226- return self .__get_status (timeout = timeout )
227-
228- def get_result (self , * , timeout : float = 60 ) -> RunResult [ToolCall ]:
229- return self .__get_result (timeout = timeout )
230-
231210 def listen (
232211 self ,
233212 * ,
@@ -239,20 +218,6 @@ def listen(
239218 timeout = timeout ,
240219 )
241220
242- def wait (
243- self ,
244- * ,
245- timeout : float = 60 ,
246- poll_timeout : int = 300 ,
247- poll_interval : float = 0.5 ,
248- ) -> RunResult [ToolCall ]:
249- # NB: mypy can't unterstand normally __wait return type and thinks its ResultTypeT
250- return self .__wait ( # type: ignore[return-value]
251- timeout = timeout ,
252- poll_timeout = poll_timeout ,
253- poll_interval = poll_interval ,
254- )
255-
256221 def submit_tool_results (
257222 self ,
258223 tool_results : ToolResultInputType ,
0 commit comments