Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/yandex_cloud_ml_sdk/_datasets/draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def _validate_deferred(
self._transform_operation_result,
raise_on_validation_failure=raise_on_validation_failure,
),
default_poll_timeout=DEFAULT_OPERATION_POLL_TIMEOUT,
custom_default_poll_timeout=DEFAULT_OPERATION_POLL_TIMEOUT,
)

async def _upload_deferred(
Expand Down
63 changes: 14 additions & 49 deletions src/yandex_cloud_ml_sdk/_runs/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import dataclasses
from datetime import datetime
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, TypeVar, cast
from typing import TYPE_CHECKING, Any, AsyncIterator, ClassVar, Iterator, TypeVar, cast

from google.protobuf.wrappers_pb2 import Int64Value
from yandex.cloud.ai.assistants.v1.runs.run_pb2 import Run as ProtoRun
Expand All @@ -16,7 +16,7 @@
from yandex_cloud_ml_sdk._tools.tool_result import (
ProtoAssistantToolResultList, ToolResultInputType, tool_results_to_proto
)
from yandex_cloud_ml_sdk._types.operation import OperationInterface
from yandex_cloud_ml_sdk._types.operation import AsyncOperationMixin, OperationInterface, SyncOperationMixin
from yandex_cloud_ml_sdk._types.resource import BaseResource
from yandex_cloud_ml_sdk._types.result import ProtoMessage
from yandex_cloud_ml_sdk._types.schemas import ResponseType
Expand All @@ -31,7 +31,7 @@


@dataclasses.dataclass(frozen=True)
class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT]]):
class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT], RunStatus]):
id: str
assistant_id: str
thread_id: str
Expand All @@ -43,6 +43,9 @@ class BaseRun(BaseResource, OperationInterface[RunResult[ToolCallTypeT]]):
custom_prompt_truncation_options: PromptTruncationOptions | None
custom_response_format: ResponseType | None

_default_poll_timeout: ClassVar[int] = 300
_default_poll_interval: ClassVar[float] = 0.5

@property
def custom_max_prompt_tokens(self) -> int | None:
if self.custom_prompt_truncation_options:
Expand Down Expand Up @@ -167,14 +170,15 @@ async def requests() -> AsyncIterator[AttachRunRequest]:

return

async def _cancel(
self,
*,
timeout: float = 60
) -> None:
raise NotImplementedError("Run couldn't be cancelled")

class AsyncRun(BaseRun[AsyncToolCall]):
async def get_status(self, *, timeout: float = 60) -> RunStatus:
return await self._get_status(timeout=timeout)

async def get_result(self, *, timeout: float = 60) -> RunResult[AsyncToolCall]:
return await self._get_result(timeout=timeout)

class AsyncRun(AsyncOperationMixin[RunResult[AsyncToolCall], RunStatus], BaseRun[AsyncToolCall]):
async def listen(
self,
*,
Expand All @@ -189,22 +193,6 @@ async def listen(

__aiter__ = listen

async def wait(
self,
*,
timeout: float = 60,
poll_timeout: int = 300,
poll_interval: float = 0.5,
) -> RunResult[AsyncToolCall]:
return await self._wait(
timeout=timeout,
poll_timeout=poll_timeout,
poll_interval=poll_interval,
)

def __await__(self):
return self.wait().__await__()

async def submit_tool_results(
self,
tool_results: ToolResultInputType,
Expand All @@ -214,20 +202,11 @@ async def submit_tool_results(
await super()._submit_tool_results(tool_results=tool_results, timeout=timeout)


class Run(BaseRun[ToolCall]):
__get_status = run_sync(BaseRun._get_status)
__get_result = run_sync(BaseRun._get_result)
__wait = run_sync(BaseRun._wait)
class Run(SyncOperationMixin[RunResult[ToolCall], RunStatus], BaseRun[ToolCall]):
__listen = run_sync_generator(BaseRun._listen)
__iter__ = __listen
__submit_tool_results = run_sync(BaseRun._submit_tool_results)

def get_status(self, *, timeout: float = 60) -> RunStatus:
return self.__get_status(timeout=timeout)

def get_result(self, *, timeout: float = 60) -> RunResult[ToolCall]:
return self.__get_result(timeout=timeout)

def listen(
self,
*,
Expand All @@ -239,20 +218,6 @@ def listen(
timeout=timeout,
)

def wait(
self,
*,
timeout: float = 60,
poll_timeout: int = 300,
poll_interval: float = 0.5,
) -> RunResult[ToolCall]:
# NB: mypy can't unterstand normally __wait return type and thinks its ResultTypeT
return self.__wait( # type: ignore[return-value]
timeout=timeout,
poll_timeout=poll_timeout,
poll_interval=poll_interval,
)

def submit_tool_results(
self,
tool_results: ToolResultInputType,
Expand Down
29 changes: 8 additions & 21 deletions src/yandex_cloud_ml_sdk/_runs/status.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
# pylint: disable=no-name-in-module
from __future__ import annotations

import enum
from enum import IntEnum

from yandex.cloud.ai.assistants.v1.runs.run_pb2 import RunState as ProtoRunState
from yandex.cloud.ai.assistants.v1.runs.run_service_pb2 import StreamEvent as ProtoStreamEvent

from yandex_cloud_ml_sdk._types.operation import BaseOperationStatus
from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase

class BaseRunStatus:
@property
def is_running(self) -> bool:
raise NotImplementedError()

@property
def is_succeeded(self) -> bool:
raise NotImplementedError()

@property
def is_failed(self) -> bool:
raise NotImplementedError()
# pylint: disable=abstract-method
class BaseRunStatus(BaseOperationStatus):
pass


class RunStatus(BaseRunStatus, int, enum.Enum):
class RunStatus(BaseRunStatus, ProtoEnumBase, IntEnum):
UNKNOWN = -1
RUN_STATUS_UNSPECIFIED = ProtoRunState.RUN_STATUS_UNSPECIFIED
PENDING = ProtoRunState.PENDING
Expand All @@ -42,15 +36,8 @@ def is_succeeded(self) -> bool:
def is_failed(self) -> bool:
return self is self.FAILED

@classmethod
def _from_proto(cls, proto: int) -> RunStatus:
try:
return cls(proto)
except ValueError:
return cls(-1)


class StreamEvent(BaseRunStatus, int, enum.Enum):
class StreamEvent(BaseRunStatus, ProtoEnumBase, IntEnum):
UNKNOWN = -1
EVENT_TYPE_UNSPECIFIED = ProtoStreamEvent.EVENT_TYPE_UNSPECIFIED
PARTIAL_MESSAGE = ProtoStreamEvent.PARTIAL_MESSAGE
Expand Down
68 changes: 12 additions & 56 deletions src/yandex_cloud_ml_sdk/_tuning/tuning_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub

from yandex_cloud_ml_sdk._logging import TRACE, get_logger
from yandex_cloud_ml_sdk._types.operation import OperationErrorInfo, OperationInterface, OperationStatus
from yandex_cloud_ml_sdk._types.operation import (
AsyncOperationMixin, OperationErrorInfo, OperationInterface, OperationStatus, SyncOperationMixin
)
from yandex_cloud_ml_sdk._types.resource import BaseResource
from yandex_cloud_ml_sdk._types.result import ProtoMessage
from yandex_cloud_ml_sdk._utils.sync import run_sync
Expand Down Expand Up @@ -89,7 +91,7 @@ def _from_tuning_info(cls, info: TuningTaskInfo) -> TuningTaskStatus:
)


class BaseTuningTask(OperationInterface[TuningResultTypeT_co]):
class BaseTuningTask(OperationInterface[TuningResultTypeT_co, TuningTaskStatus]):
_sdk: BaseSDK

def __init__(
Expand Down Expand Up @@ -310,73 +312,27 @@ async def _get_metrics_url(self, *, timeout: float = 60) -> str | None:
return None


class AsyncTuningTask(BaseTuningTask[TuningResultTypeT_co]):
class AsyncTuningTask(
AsyncOperationMixin[TuningResultTypeT_co, TuningTaskStatus],
BaseTuningTask[TuningResultTypeT_co]
):
async def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo | None:
return await self._get_task_info(timeout=timeout)

async def get_status(self, *, timeout: float = 60) -> TuningTaskStatus:
return await self._get_status(timeout=timeout)

async def get_result(self, *, timeout: float = 60) -> TuningResultTypeT_co:
return await self._get_result(timeout=timeout)

async def cancel(self, *, timeout: float = 60) -> None:
await self._cancel(timeout=timeout)

async def wait(
self,
*,
timeout: float = 60,
poll_timeout: int = 72 * 60 * 60,
poll_interval: float = 10,
) -> TuningResultTypeT_co:
return await self._wait(
timeout=timeout,
poll_timeout=poll_timeout,
poll_interval=poll_interval,
)

async def get_metrics_url(self, *, timeout: float = 60) -> str | None:
return await self._get_metrics_url(timeout=timeout)

def __await__(self):
return self.wait().__await__()


class TuningTask(BaseTuningTask[TuningResultTypeT_co]):
__get_status = run_sync(BaseTuningTask._get_status)
__get_result = run_sync(BaseTuningTask._get_result)
__wait = run_sync(BaseTuningTask._wait)
__cancel = run_sync(BaseTuningTask._cancel)
class TuningTask(
SyncOperationMixin[TuningResultTypeT_co, TuningTaskStatus],
BaseTuningTask[TuningResultTypeT_co]
):
__get_metrics_url = run_sync(BaseTuningTask._get_metrics_url)
__get_task_info = run_sync(BaseTuningTask._get_task_info)

def get_task_info(self, *, timeout: float = 60) -> TuningTaskInfo | None:
return self.__get_task_info(timeout=timeout)

def get_status(self, *, timeout: float = 60) -> TuningTaskStatus:
return self.__get_status(timeout=timeout)

def get_result(self, *, timeout: float = 60) -> TuningResultTypeT_co:
return self.__get_result(timeout=timeout)

def cancel(self, *, timeout: float = 60) -> None:
self.__cancel(timeout=timeout)

def wait(
self,
*,
timeout: float = 60,
poll_timeout: int = 72 * 60 * 60,
poll_interval: float = 10,
) -> TuningResultTypeT_co:
result = self.__wait(
timeout=timeout,
poll_timeout=poll_timeout,
poll_interval=poll_interval,
)
return cast(TuningResultTypeT_co, result)

def get_metrics_url(self, *, timeout: float = 60) -> str | None:
return self.__get_metrics_url(timeout=timeout)

Expand Down
2 changes: 1 addition & 1 deletion src/yandex_cloud_ml_sdk/_types/batch/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
transformer=self._result_transformer,
service_name='ai-foundation-models',
initial_operation=initial_operation,
default_poll_timeout=60 * 60 * 72, # 72h
custom_default_poll_timeout=60 * 60 * 72, # 72h
)

# NB: I don't want to make parent operation class Generic[MetadataTypeT] just to
Expand Down
Loading