Skip to content

Commit 5a07bc9

Browse files
authored
Refactor operations inheritance structure (#129)
1 parent 4883018 commit 5a07bc9

File tree

7 files changed

+139
-179
lines changed

7 files changed

+139
-179
lines changed

src/yandex_cloud_ml_sdk/_datasets/draft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def _validate_deferred(
119119
self._transform_operation_result,
120120
raise_on_validation_failure=raise_on_validation_failure,
121121
),
122-
default_poll_timeout=DEFAULT_OPERATION_POLL_TIMEOUT,
122+
custom_default_poll_timeout=DEFAULT_OPERATION_POLL_TIMEOUT,
123123
)
124124

125125
async def _upload_deferred(

src/yandex_cloud_ml_sdk/_runs/run.py

Lines changed: 14 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import dataclasses
55
from datetime import datetime
6-
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, TypeVar, cast
6+
from typing import TYPE_CHECKING, Any, AsyncIterator, ClassVar, Iterator, TypeVar, cast
77

88
from google.protobuf.wrappers_pb2 import Int64Value
99
from yandex.cloud.ai.assistants.v1.runs.run_pb2 import Run as ProtoRun
@@ -16,7 +16,7 @@
1616
from yandex_cloud_ml_sdk._tools.tool_result import (
1717
ProtoAssistantToolResultList, ToolResultInputType, tool_results_to_proto
1818
)
19-
from yandex_cloud_ml_sdk._types.operation import OperationInterface
19+
from yandex_cloud_ml_sdk._types.operation import AsyncOperationMixin, OperationInterface, SyncOperationMixin
2020
from yandex_cloud_ml_sdk._types.resource import BaseResource
2121
from yandex_cloud_ml_sdk._types.result import ProtoMessage
2222
from yandex_cloud_ml_sdk._types.schemas import ResponseType
@@ -31,7 +31,7 @@
3131

3232

3333
@dataclasses.dataclass(frozen=True)
34-
class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT]]):
34+
class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT], RunStatus]):
3535
id: str
3636
assistant_id: str
3737
thread_id: str
@@ -43,6 +43,9 @@ class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT]]):
4343
custom_prompt_truncation_options: PromptTruncationOptions | None
4444
custom_response_format: ResponseType | None
4545

46+
_default_poll_timeout: ClassVar[int] = 300
47+
_default_poll_interval: ClassVar[float] = 0.5
48+
4649
@property
4750
def custom_max_prompt_tokens(self) -> int | None:
4851
if self.custom_prompt_truncation_options:
@@ -167,14 +170,15 @@ async def requests() -> AsyncIterator[AttachRunRequest]:
167170

168171
return
169172

173+
async def _cancel(
174+
self,
175+
*,
176+
timeout: float = 60
177+
) -> None:
178+
raise NotImplementedError("Run couldn't be cancelled")
170179

171-
class AsyncRun(BaseRun[AsyncToolCall]):
172-
async def get_status(self, *, timeout: float = 60) -> RunStatus:
173-
return await self._get_status(timeout=timeout)
174-
175-
async def get_result(self, *, timeout: float = 60) -> RunResult[AsyncToolCall]:
176-
return await self._get_result(timeout=timeout)
177180

181+
class AsyncRun(AsyncOperationMixin[RunResult[AsyncToolCall], RunStatus], BaseRun[AsyncToolCall]):
178182
async def listen(
179183
self,
180184
*,
@@ -189,22 +193,6 @@ async def listen(
189193

190194
__aiter__ = listen
191195

192-
async def wait(
193-
self,
194-
*,
195-
timeout: float = 60,
196-
poll_timeout: int = 300,
197-
poll_interval: float = 0.5,
198-
) -> RunResult[AsyncToolCall]:
199-
return await self._wait(
200-
timeout=timeout,
201-
poll_timeout=poll_timeout,
202-
poll_interval=poll_interval,
203-
)
204-
205-
def __await__(self):
206-
return self.wait().__await__()
207-
208196
async def submit_tool_results(
209197
self,
210198
tool_results: ToolResultInputType,
@@ -214,20 +202,11 @@ async def submit_tool_results(
214202
await super()._submit_tool_results(tool_results=tool_results, timeout=timeout)
215203

216204

217-
class Run(BaseRun[ToolCall]):
218-
__get_status = run_sync(BaseRun._get_status)
219-
__get_result = run_sync(BaseRun._get_result)
220-
__wait = run_sync(BaseRun._wait)
205+
class Run(SyncOperationMixin[RunResult[ToolCall], RunStatus], BaseRun[ToolCall]):
221206
__listen = run_sync_generator(BaseRun._listen)
222207
__iter__ = __listen
223208
__submit_tool_results = run_sync(BaseRun._submit_tool_results)
224209

225-
def get_status(self, *, timeout: float = 60) -> RunStatus:
226-
return self.__get_status(timeout=timeout)
227-
228-
def get_result(self, *, timeout: float = 60) -> RunResult[ToolCall]:
229-
return self.__get_result(timeout=timeout)
230-
231210
def listen(
232211
self,
233212
*,
@@ -239,20 +218,6 @@ def listen(
239218
timeout=timeout,
240219
)
241220

242-
def wait(
243-
self,
244-
*,
245-
timeout: float = 60,
246-
poll_timeout: int = 300,
247-
poll_interval: float = 0.5,
248-
) -> RunResult[ToolCall]:
249-
# NB: mypy can't unterstand normally __wait return type and thinks its ResultTypeT
250-
return self.__wait( # type: ignore[return-value]
251-
timeout=timeout,
252-
poll_timeout=poll_timeout,
253-
poll_interval=poll_interval,
254-
)
255-
256221
def submit_tool_results(
257222
self,
258223
tool_results: ToolResultInputType,

src/yandex_cloud_ml_sdk/_runs/status.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,21 @@
11
# pylint: disable=no-name-in-module
22
from __future__ import annotations
33

4-
import enum
4+
from enum import IntEnum
55

66
from yandex.cloud.ai.assistants.v1.runs.run_pb2 import RunState as ProtoRunState
77
from yandex.cloud.ai.assistants.v1.runs.run_service_pb2 import StreamEvent as ProtoStreamEvent
88

9+
from yandex_cloud_ml_sdk._types.operation import BaseOperationStatus
10+
from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase
911

10-
class BaseRunStatus:
11-
@property
12-
def is_running(self) -> bool:
13-
raise NotImplementedError()
1412

15-
@property
16-
def is_succeeded(self) -> bool:
17-
raise NotImplementedError()
18-
19-
@property
20-
def is_failed(self) -> bool:
21-
raise NotImplementedError()
13+
# pylint: disable=abstract-method
14+
class BaseRunStatus(BaseOperationStatus):
15+
pass
2216

2317

24-
class RunStatus(BaseRunStatus, int, enum.Enum):
18+
class RunStatus(BaseRunStatus, ProtoEnumBase, IntEnum):
2519
UNKNOWN = -1
2620
RUN_STATUS_UNSPECIFIED = ProtoRunState.RUN_STATUS_UNSPECIFIED
2721
PENDING = ProtoRunState.PENDING
@@ -42,15 +36,8 @@ def is_succeeded(self) -> bool:
4236
def is_failed(self) -> bool:
4337
return self is self.FAILED
4438

45-
@classmethod
46-
def _from_proto(cls, proto: int) -> RunStatus:
47-
try:
48-
return cls(proto)
49-
except ValueError:
50-
return cls(-1)
51-
5239

53-
class StreamEvent(BaseRunStatus, int, enum.Enum):
40+
class StreamEvent(BaseRunStatus, ProtoEnumBase, IntEnum):
5441
UNKNOWN = -1
5542
EVENT_TYPE_UNSPECIFIED = ProtoStreamEvent.EVENT_TYPE_UNSPECIFIED
5643
PARTIAL_MESSAGE = ProtoStreamEvent.PARTIAL_MESSAGE

src/yandex_cloud_ml_sdk/_tuning/tuning_task.py

Lines changed: 12 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub
1919

2020
from yandex_cloud_ml_sdk._logging import TRACE, get_logger
21-
from yandex_cloud_ml_sdk._types.operation import OperationErrorInfo, OperationInterface, OperationStatus
21+
from yandex_cloud_ml_sdk._types.operation import (
22+
AsyncOperationMixin, OperationErrorInfo, OperationInterface, OperationStatus, SyncOperationMixin
23+
)
2224
from yandex_cloud_ml_sdk._types.resource import BaseResource
2325
from yandex_cloud_ml_sdk._types.result import ProtoMessage
2426
from yandex_cloud_ml_sdk._utils.sync import run_sync
@@ -89,7 +91,7 @@ def _from_tuning_info(cls, info: TuningTaskInfo) -> TuningTaskStatus:
8991
)
9092

9193

92-
class BaseTuningTask(OperationInterface[TuningResultTypeT_co]):
94+
class BaseTuningTask(OperationInterface[TuningResultTypeT_co, TuningTaskStatus]):
9395
_sdk: BaseSDK
9496

9597
def __init__(
@@ -310,73 +312,27 @@ async def _get_metrics_url(self, *, timeout: float = 60) -> str | None:
310312
return None
311313

312314

313-
class AsyncTuningTask(BaseTuningTask[TuningResultTypeT_co]):
315+
class AsyncTuningTask(
316+
AsyncOperationMixin[TuningResultTypeT_co, TuningTaskStatus],
317+
BaseTuningTask[TuningResultTypeT_co]
318+
):
314319
async def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo | None:
315320
return await self._get_task_info(timeout=timeout)
316321

317-
async def get_status(self, *, timeout: float = 60) -> TuningTaskStatus:
318-
return await self._get_status(timeout=timeout)
319-
320-
async def get_result(self, *, timeout: float = 60) -> TuningResultTypeT_co:
321-
return await self._get_result(timeout=timeout)
322-
323-
async def cancel(self, *, timeout: float = 60) -> None:
324-
await self._cancel(timeout=timeout)
325-
326-
async def wait(
327-
self,
328-
*,
329-
timeout: float = 60,
330-
poll_timeout: int = 72 * 60 * 60,
331-
poll_interval: float = 10,
332-
) -> TuningResultTypeT_co:
333-
return await self._wait(
334-
timeout=timeout,
335-
poll_timeout=poll_timeout,
336-
poll_interval=poll_interval,
337-
)
338-
339322
async def get_metrics_url(self, *, timeout: float = 60) -> str | None:
340323
return await self._get_metrics_url(timeout=timeout)
341324

342-
def __await__(self):
343-
return self.wait().__await__()
344325

345-
346-
class TuningTask(BaseTuningTask[TuningResultTypeT_co]):
347-
__get_status = run_sync(BaseTuningTask._get_status)
348-
__get_result = run_sync(BaseTuningTask._get_result)
349-
__wait = run_sync(BaseTuningTask._wait)
350-
__cancel = run_sync(BaseTuningTask._cancel)
326+
class TuningTask(
327+
SyncOperationMixin[TuningResultTypeT_co, TuningTaskStatus],
328+
BaseTuningTask[TuningResultTypeT_co]
329+
):
351330
__get_metrics_url = run_sync(BaseTuningTask._get_metrics_url)
352331
__get_task_info = run_sync(BaseTuningTask._get_task_info)
353332

354333
def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo | None:
355334
return self.__get_task_info(timeout=timeout)
356335

357-
def get_status(self, *, timeout: float = 60) -> TuningTaskStatus:
358-
return self.__get_status(timeout=timeout)
359-
360-
def get_result(self, *, timeout: float = 60) -> TuningResultTypeT_co:
361-
return self.__get_result(timeout=timeout)
362-
363-
def cancel(self, *, timeout: float = 60) -> None:
364-
self.__cancel(timeout=timeout)
365-
366-
def wait(
367-
self,
368-
*,
369-
timeout: float = 60,
370-
poll_timeout: int = 72 * 60 * 60,
371-
poll_interval: float = 10,
372-
) -> TuningResultTypeT_co:
373-
result = self.__wait(
374-
timeout=timeout,
375-
poll_timeout=poll_timeout,
376-
poll_interval=poll_interval,
377-
)
378-
return cast(TuningResultTypeT_co, result)
379-
380336
def get_metrics_url(self, *, timeout: float = 60) -> str | None:
381337
return self.__get_metrics_url(timeout=timeout)
382338

src/yandex_cloud_ml_sdk/_types/batch/operation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
transformer=self._result_transformer,
4242
service_name='ai-foundation-models',
4343
initial_operation=initial_operation,
44-
default_poll_timeout=60 * 60 * 72, # 72h
44+
custom_default_poll_timeout=60 * 60 * 72, # 72h
4545
)
4646

4747
# NB: I don't want to make parent operation class Generic[MetadataTypeT] just to

0 commit comments

Comments
 (0)