diff --git a/src/yandex_cloud_ml_sdk/_datasets/draft.py b/src/yandex_cloud_ml_sdk/_datasets/draft.py index 2d773f86..ce56c8a7 100644 --- a/src/yandex_cloud_ml_sdk/_datasets/draft.py +++ b/src/yandex_cloud_ml_sdk/_datasets/draft.py @@ -119,7 +119,7 @@ async def _validate_deferred( self._transform_operation_result, raise_on_validation_failure=raise_on_validation_failure, ), - default_poll_timeout=DEFAULT_OPERATION_POLL_TIMEOUT, + custom_default_poll_timeout=DEFAULT_OPERATION_POLL_TIMEOUT, ) async def _upload_deferred( diff --git a/src/yandex_cloud_ml_sdk/_runs/run.py b/src/yandex_cloud_ml_sdk/_runs/run.py index 22f8ad74..c8632549 100644 --- a/src/yandex_cloud_ml_sdk/_runs/run.py +++ b/src/yandex_cloud_ml_sdk/_runs/run.py @@ -3,7 +3,7 @@ import dataclasses from datetime import datetime -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, TypeVar, cast +from typing import TYPE_CHECKING, Any, AsyncIterator, ClassVar, Iterator, TypeVar, cast from google.protobuf.wrappers_pb2 import Int64Value from yandex.cloud.ai.assistants.v1.runs.run_pb2 import Run as ProtoRun @@ -16,7 +16,7 @@ from yandex_cloud_ml_sdk._tools.tool_result import ( ProtoAssistantToolResultList, ToolResultInputType, tool_results_to_proto ) -from yandex_cloud_ml_sdk._types.operation import OperationInterface +from yandex_cloud_ml_sdk._types.operation import AsyncOperationMixin, OperationInterface, SyncOperationMixin from yandex_cloud_ml_sdk._types.resource import BaseResource from yandex_cloud_ml_sdk._types.result import ProtoMessage from yandex_cloud_ml_sdk._types.schemas import ResponseType @@ -31,7 +31,7 @@ @dataclasses.dataclass(frozen=True) -class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT]]): +class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT], RunStatus]): id: str assistant_id: str thread_id: str @@ -43,6 +43,9 @@ class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT]]): custom_prompt_truncation_options: PromptTruncationOptions | None custom_response_format: ResponseType | None + _default_poll_timeout: ClassVar[int] = 300 + _default_poll_interval: ClassVar[float] = 0.5 + @property def custom_max_prompt_tokens(self) -> int | None: if self.custom_prompt_truncation_options: @@ -167,14 +170,15 @@ async def requests() -> AsyncIterator[AttachRunRequest]: return + async def _cancel( + self, + *, + timeout: float = 60 + ) -> None: + raise NotImplementedError("Run couldn't be cancelled") -class AsyncRun(BaseRun[AsyncToolCall]): - async def get_status(self, *, timeout: float = 60) -> RunStatus: - return await self._get_status(timeout=timeout) - - async def get_result(self, *, timeout: float = 60) -> RunResult[AsyncToolCall]: - return await self._get_result(timeout=timeout) +class AsyncRun(AsyncOperationMixin[RunResult[AsyncToolCall], RunStatus], BaseRun[AsyncToolCall]): async def listen( self, *, @@ -189,22 +193,6 @@ async def listen( __aiter__ = listen - async def wait( - self, - *, - timeout: float = 60, - poll_timeout: int = 300, - poll_interval: float = 0.5, - ) -> RunResult[AsyncToolCall]: - return await self._wait( - timeout=timeout, - poll_timeout=poll_timeout, - poll_interval=poll_interval, - ) - - def __await__(self): - return self.wait().__await__() - async def submit_tool_results( self, tool_results: ToolResultInputType, @@ -214,20 +202,11 @@ async def submit_tool_results( await super()._submit_tool_results(tool_results=tool_results, timeout=timeout) -class Run(BaseRun[ToolCall]): - __get_status = run_sync(BaseRun._get_status) - __get_result = run_sync(BaseRun._get_result) - __wait = run_sync(BaseRun._wait) +class Run(SyncOperationMixin[RunResult[ToolCall], RunStatus], BaseRun[ToolCall]): __listen = run_sync_generator(BaseRun._listen) __iter__ = __listen __submit_tool_results = run_sync(BaseRun._submit_tool_results) - def get_status(self, *, timeout: float = 60) -> RunStatus: - return self.__get_status(timeout=timeout) - - def get_result(self, *, timeout: float = 60) -> RunResult[ToolCall]: - return self.__get_result(timeout=timeout) - def listen( self, *, @@ -239,20 +218,6 @@ def listen( timeout=timeout, ) - def wait( - self, - *, - timeout: float = 60, - poll_timeout: int = 300, - poll_interval: float = 0.5, - ) -> RunResult[ToolCall]: - # NB: mypy can't unterstand normally __wait return type and thinks its ResultTypeT - return self.__wait( # type: ignore[return-value] - timeout=timeout, - poll_timeout=poll_timeout, - poll_interval=poll_interval, - ) - def submit_tool_results( self, tool_results: ToolResultInputType, diff --git a/src/yandex_cloud_ml_sdk/_runs/status.py b/src/yandex_cloud_ml_sdk/_runs/status.py index 45e4b3fe..5cffada7 100644 --- a/src/yandex_cloud_ml_sdk/_runs/status.py +++ b/src/yandex_cloud_ml_sdk/_runs/status.py @@ -1,27 +1,21 @@ # pylint: disable=no-name-in-module from __future__ import annotations -import enum +from enum import IntEnum from yandex.cloud.ai.assistants.v1.runs.run_pb2 import RunState as ProtoRunState from yandex.cloud.ai.assistants.v1.runs.run_service_pb2 import StreamEvent as ProtoStreamEvent +from yandex_cloud_ml_sdk._types.operation import BaseOperationStatus +from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase -class BaseRunStatus: - @property - def is_running(self) -> bool: - raise NotImplementedError() - @property - def is_succeeded(self) -> bool: - raise NotImplementedError() - - @property - def is_failed(self) -> bool: - raise NotImplementedError() +# pylint: disable=abstract-method +class BaseRunStatus(BaseOperationStatus): + pass -class RunStatus(BaseRunStatus, int, enum.Enum): +class RunStatus(BaseRunStatus, ProtoEnumBase, IntEnum): UNKNOWN = -1 RUN_STATUS_UNSPECIFIED = ProtoRunState.RUN_STATUS_UNSPECIFIED PENDING = ProtoRunState.PENDING @@ -42,15 +36,8 @@ def is_succeeded(self) -> bool: def is_failed(self) -> bool: return self is self.FAILED - @classmethod - def _from_proto(cls, proto: int) -> RunStatus: - try: - return cls(proto) - except ValueError: - return cls(-1) - -class StreamEvent(BaseRunStatus, int, enum.Enum): +class StreamEvent(BaseRunStatus, ProtoEnumBase, IntEnum): UNKNOWN = -1 EVENT_TYPE_UNSPECIFIED = ProtoStreamEvent.EVENT_TYPE_UNSPECIFIED PARTIAL_MESSAGE = ProtoStreamEvent.PARTIAL_MESSAGE diff --git a/src/yandex_cloud_ml_sdk/_tuning/tuning_task.py b/src/yandex_cloud_ml_sdk/_tuning/tuning_task.py index b492baee..a95c9a90 100644 --- a/src/yandex_cloud_ml_sdk/_tuning/tuning_task.py +++ b/src/yandex_cloud_ml_sdk/_tuning/tuning_task.py @@ -18,7 +18,9 @@ from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub from yandex_cloud_ml_sdk._logging import TRACE, get_logger -from yandex_cloud_ml_sdk._types.operation import OperationErrorInfo, OperationInterface, OperationStatus +from yandex_cloud_ml_sdk._types.operation import ( + AsyncOperationMixin, OperationErrorInfo, OperationInterface, OperationStatus, SyncOperationMixin +) from yandex_cloud_ml_sdk._types.resource import BaseResource from yandex_cloud_ml_sdk._types.result import ProtoMessage from yandex_cloud_ml_sdk._utils.sync import run_sync @@ -89,7 +91,7 @@ def _from_tuning_info(cls, info: TuningTaskInfo) -> TuningTaskStatus: ) -class BaseTuningTask(OperationInterface[TuningResultTypeT_co]): +class BaseTuningTask(OperationInterface[TuningResultTypeT_co, TuningTaskStatus]): _sdk: BaseSDK def __init__( @@ -310,73 +312,27 @@ async def _get_metrics_url(self, *, timeout: float = 60) -> str | None: return None -class AsyncTuningTask(BaseTuningTask[TuningResultTypeT_co]): +class AsyncTuningTask( + AsyncOperationMixin[TuningResultTypeT_co, TuningTaskStatus], + BaseTuningTask[TuningResultTypeT_co] +): async def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo | None: return await self._get_task_info(timeout=timeout) - async def get_status(self, *, timeout: float = 60) -> TuningTaskStatus: - return await self._get_status(timeout=timeout) - - async def get_result(self, *, timeout: float = 60) -> TuningResultTypeT_co: - return await self._get_result(timeout=timeout) - - async def cancel(self, *, timeout: float = 60) -> None: - await self._cancel(timeout=timeout) - - async def wait( - self, - *, - timeout: float = 60, - poll_timeout: int = 72 * 60 * 60, - poll_interval: float = 10, - ) -> TuningResultTypeT_co: - return await self._wait( - timeout=timeout, - poll_timeout=poll_timeout, - poll_interval=poll_interval, - ) - async def get_metrics_url(self, *, timeout: float = 60) -> str | None: return await self._get_metrics_url(timeout=timeout) - def __await__(self): - return self.wait().__await__() - -class TuningTask(BaseTuningTask[TuningResultTypeT_co]): - __get_status = run_sync(BaseTuningTask._get_status) - __get_result = run_sync(BaseTuningTask._get_result) - __wait = run_sync(BaseTuningTask._wait) - __cancel = run_sync(BaseTuningTask._cancel) +class TuningTask( + SyncOperationMixin[TuningResultTypeT_co, TuningTaskStatus], + BaseTuningTask[TuningResultTypeT_co] +): __get_metrics_url = run_sync(BaseTuningTask._get_metrics_url) __get_task_info = run_sync(BaseTuningTask._get_task_info) def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo | None: return self.__get_task_info(timeout=timeout) - def get_status(self, *, timeout: float = 60) -> TuningTaskStatus: - return self.__get_status(timeout=timeout) - - def get_result(self, *, timeout: float = 60) -> TuningResultTypeT_co: - return self.__get_result(timeout=timeout) - - def cancel(self, *, timeout: float = 60) -> None: - self.__cancel(timeout=timeout) - - def wait( - self, - *, - timeout: float = 60, - poll_timeout: int = 72 * 60 * 60, - poll_interval: float = 10, - ) -> TuningResultTypeT_co: - result = self.__wait( - timeout=timeout, - poll_timeout=poll_timeout, - poll_interval=poll_interval, - ) - return cast(TuningResultTypeT_co, result) - def get_metrics_url(self, *, timeout: float = 60) -> str | None: return self.__get_metrics_url(timeout=timeout) diff --git a/src/yandex_cloud_ml_sdk/_types/batch/operation.py b/src/yandex_cloud_ml_sdk/_types/batch/operation.py index 254f0649..28fc10e2 100644 --- a/src/yandex_cloud_ml_sdk/_types/batch/operation.py +++ b/src/yandex_cloud_ml_sdk/_types/batch/operation.py @@ -41,7 +41,7 @@ def __init__( transformer=self._result_transformer, service_name='ai-foundation-models', initial_operation=initial_operation, - default_poll_timeout=60 * 60 * 72, # 72h + custom_default_poll_timeout=60 * 60 * 72, # 72h ) # NB: I don't want to make parent operation class Generic[MetadataTypeT] just to diff --git a/src/yandex_cloud_ml_sdk/_types/operation.py b/src/yandex_cloud_ml_sdk/_types/operation.py index 52d7b2a8..362e4269 100644 --- a/src/yandex_cloud_ml_sdk/_types/operation.py +++ b/src/yandex_cloud_ml_sdk/_types/operation.py @@ -4,7 +4,7 @@ import abc import asyncio from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generic, Iterable, TypeVar, cast, get_origin +from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Generic, Iterable, TypeVar, cast, get_origin from google.protobuf.message import Message from typing_extensions import Self @@ -15,7 +15,7 @@ from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub from yandex_cloud_ml_sdk._logging import TRACE, get_logger -from yandex_cloud_ml_sdk._utils.sync import run_sync +from yandex_cloud_ml_sdk._utils.sync import run_sync_impl from yandex_cloud_ml_sdk.exceptions import RunError, WrongAsyncOperationStatusError from .proto import ProtoBasedType @@ -29,6 +29,29 @@ ResultTypeT_co = TypeVar('ResultTypeT_co', covariant=True) +# NB: it couldn't be ABC because it descendants can't inherit from ABC and Enum at the same time +class BaseOperationStatus: + @property + def is_running(self) -> bool: + raise NotImplementedError() + + @property + def is_succeeded(self) -> bool: + raise NotImplementedError() + + @property + def is_failed(self) -> bool: + raise NotImplementedError() + + @property + def status_name(self) -> str: + if self.is_succeeded: + return 'success' + if self.is_failed: + return 'failed' + return 'runnning' + + @dataclass(frozen=True) class OperationErrorInfo: code: int @@ -37,7 +60,7 @@ class OperationErrorInfo: @dataclass(frozen=True) -class OperationStatus: +class OperationStatus(BaseOperationStatus): done: bool error: OperationErrorInfo | None # TBD: google.rpc.Status response: Any | None = field(repr=False) @@ -74,43 +97,46 @@ def _from_proto(cls, *, proto: ProtoOperation) -> Self: metadata=proto.metadata ) - @property - def name(self) -> str: - if self.is_succeeded: - return 'success' - if self.is_failed: - return 'failed' - return 'runnning' - def __repr__(self) -> str: error_text = '' if self.is_failed: error_text = f', error={self.error}' - return f'{self.__class__.__name__}<{self.name}{error_text}>' + return f'{self.__class__.__name__}<{self.status_name}{error_text}>' -class OperationInterface(abc.ABC, Generic[AnyResultTypeT_co]): +OperationStatusTypeT = TypeVar('OperationStatusTypeT', bound=BaseOperationStatus) + + +class OperationInterface(abc.ABC, Generic[AnyResultTypeT_co, OperationStatusTypeT]): id: str + _default_poll_timeout: ClassVar[int] = 3600 + _default_poll_interval: ClassVar[float] = 10 + _custom_default_poll_timeout: int | None = None + _sdk: BaseSDK @abc.abstractmethod - async def _get_status(self, *, timeout: float = 60) -> OperationStatus: + async def _get_status(self, *, timeout: float = 60) -> OperationStatusTypeT: pass @abc.abstractmethod async def _get_result(self, *, timeout: float = 60) -> AnyResultTypeT_co: pass + @abc.abstractmethod + async def _cancel(self, *, timeout: float = 60) -> None: + pass + async def _sleep_impl(self, delay: float) -> None: # method is created for patching it in a tests await asyncio.sleep(delay) - async def _wait_impl(self, timeout: float, poll_interval: float) -> OperationStatus: + async def _wait_impl(self, timeout: float, poll_interval: float) -> OperationStatusTypeT: status = await self._get_status(timeout=timeout) while status.is_running: logger.debug( "%s have non-terminal status %s, sleep for %fs", - self, status.name, poll_interval + self, status.status_name, poll_interval ) await self._sleep_impl(poll_interval) status = await self._get_status(timeout=timeout) @@ -125,10 +151,16 @@ async def _wait_impl(self, timeout: float, poll_interval: float) -> OperationSta async def _wait( self, *, - timeout: float = 60, - poll_timeout: int = 3600, - poll_interval: float = 10, + timeout: float, + poll_timeout: int | None, + poll_interval: float | None, ) -> AnyResultTypeT_co: + # poll_timeout got from user + # custom_default_poll_timeout - from operation __init__ + # default_poll_timeout - from class + poll_timeout = poll_timeout or self._custom_default_poll_timeout or self._default_poll_timeout + poll_interval = poll_interval or self._default_poll_interval + logger.info( "Starting %s polling with a poll interval %fs and poll timeout %fs", self, poll_interval, poll_timeout, @@ -148,7 +180,7 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}' -class BaseOperation(Generic[ResultTypeT_co], OperationInterface[ResultTypeT_co]): +class BaseOperation(Generic[ResultTypeT_co], OperationInterface[ResultTypeT_co, OperationStatus]): _last_known_status: OperationStatus | None def __init__( @@ -162,7 +194,7 @@ def __init__( initial_operation: ProtoOperation | None = None, service_name: str | None = None, transformer: None | Callable[[Any, float], Awaitable[ResultTypeT_co]] = None, - default_poll_timeout: int = 3600, + custom_default_poll_timeout: int = 3600, ): # pylint: disable=redefined-builtin self._id = id self._sdk = sdk @@ -171,7 +203,7 @@ def __init__( self._proto_metadata_type = proto_metadata_type self._service_name = service_name self._transformer = transformer or self._default_result_transofrmer - self._default_poll_timeout = default_poll_timeout + self._custom_default_poll_timeout = custom_default_poll_timeout self._last_known_status = None if initial_operation: @@ -282,7 +314,7 @@ async def _get_result(self, *, timeout: float = 60) -> ResultTypeT_co: f"{self} is done but response have result neither error fields set" ) - async def _cancel(self, *, timeout: float = 60) -> OperationStatus: + async def _cancel(self, *, timeout: float = 60) -> None: logger.debug('Cancelling %s', self) request = CancelOperationRequest(operation_id=self.id) async with self._client.get_service_stub( @@ -296,18 +328,16 @@ async def _cancel(self, *, timeout: float = 60) -> OperationStatus: timeout=timeout, expected_type=ProtoOperation, ) - self._last_known_status = status = OperationStatus._from_proto(proto=response) + self._last_known_status = OperationStatus._from_proto(proto=response) logger.info('%s successfully canceled', self) - return status async def _wait( self, *, timeout: float = 60, poll_timeout: int | None = None, - poll_interval: float = 10, + poll_interval: float | None = None, ) -> ResultTypeT_co: - poll_timeout = poll_timeout or self._default_poll_timeout return await super()._wait( timeout=timeout, poll_interval=poll_interval, @@ -315,23 +345,23 @@ async def _wait( ) -class AsyncOperation(BaseOperation[ResultTypeT_co]): - async def get_status(self, *, timeout: float = 60) -> OperationStatus: +class AsyncOperationMixin(OperationInterface[AnyResultTypeT_co, OperationStatusTypeT]): + async def get_status(self, *, timeout: float = 60) -> OperationStatusTypeT: return await self._get_status(timeout=timeout) - async def get_result(self, *, timeout: float = 60) -> ResultTypeT_co: + async def get_result(self, *, timeout: float = 60) -> AnyResultTypeT_co: return await self._get_result(timeout=timeout) - async def cancel(self, *, timeout: float = 60) -> OperationStatus: - return await self._cancel(timeout=timeout) + async def cancel(self, *, timeout: float = 60) -> None: + await self._cancel(timeout=timeout) async def wait( self, *, timeout: float = 60, poll_timeout: int | None = None, - poll_interval: float = 10, - ) -> ResultTypeT_co: + poll_interval: float | None = None, + ) -> AnyResultTypeT_co: return await self._wait( timeout=timeout, poll_timeout=poll_timeout, @@ -342,35 +372,50 @@ def __await__(self): return self.wait().__await__() -class Operation(BaseOperation[ResultTypeT_co]): - __get_status = run_sync(BaseOperation._get_status) - __get_result = run_sync(BaseOperation._get_result) - __wait = run_sync(BaseOperation._wait) - __cancel = run_sync(BaseOperation._cancel) +class AsyncOperation(AsyncOperationMixin[ResultTypeT_co, OperationStatus], BaseOperation[ResultTypeT_co]): + pass - def get_status(self, *, timeout: float = 60) -> OperationStatus: - return self.__get_status(timeout=timeout) - def get_result(self, *, timeout: float = 60) -> ResultTypeT_co: - return self.__get_result(timeout=timeout) +class SyncOperationMixin(OperationInterface[AnyResultTypeT_co, OperationStatusTypeT]): + def get_status(self, *, timeout: float = 60) -> OperationStatusTypeT: + return run_sync_impl( + self._get_status(timeout=timeout), + self._sdk + ) + + def get_result(self, *, timeout: float = 60) -> AnyResultTypeT_co: + return run_sync_impl( + self._get_result(timeout=timeout), + self._sdk, + ) - def cancel(self, *, timeout: float = 60) -> OperationStatus: - return self.__cancel(timeout=timeout) + def cancel(self, *, timeout: float = 60) -> None: + run_sync_impl( + self._cancel(timeout=timeout), + self._sdk + ) def wait( self, *, timeout: float = 60, poll_timeout: int | None = None, - poll_interval: float = 10, - ) -> ResultTypeT_co: - return self.__wait( - timeout=timeout, - poll_timeout=poll_timeout, - poll_interval=poll_interval, + poll_interval: float | None = None, + ) -> AnyResultTypeT_co: + return run_sync_impl( + self._wait( + timeout=timeout, + poll_timeout=poll_timeout, + poll_interval=poll_interval, + ), + self._sdk ) +class Operation(SyncOperationMixin[ResultTypeT_co, OperationStatus], BaseOperation[ResultTypeT_co]): + pass + + OperationTypeT = TypeVar('OperationTypeT', bound=BaseOperation) # this is needed to be able to declare Generic[OperationTypeT] in a dataclasses diff --git a/src/yandex_cloud_ml_sdk/_utils/proto.py b/src/yandex_cloud_ml_sdk/_utils/proto.py index 9006601a..21f9695f 100644 --- a/src/yandex_cloud_ml_sdk/_utils/proto.py +++ b/src/yandex_cloud_ml_sdk/_utils/proto.py @@ -107,3 +107,10 @@ def _coerce(cls, value: str | int | ProtoEnumBase) -> Self: def _to_proto(self) -> int: assert hasattr(self, 'value') return self.value + + @classmethod + def _from_proto(cls, proto: int) -> Self: + try: + return cls(proto) + except ValueError: + return cls(-1)