Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/yandex_cloud_ml_sdk/_batch/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@
from yandex_cloud_ml_sdk._types.batch.task_info import BatchTaskInfo
from yandex_cloud_ml_sdk._types.domain import BaseDomain
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
from yandex_cloud_ml_sdk._utils.doc import doc_from
from yandex_cloud_ml_sdk._utils.proto import ProtoEnumCoercible
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator

logger = get_logger(__name__)


class BaseBatch(BaseDomain, Generic[BatchTaskOperationTypeT]):
"""
Сlass for managing batch operations in Yandex Cloud ML SDK.

For usage examples see `batch example <https://github.com/yandex-cloud/yandex-cloud-ml-sdk/blob/master/examples/{link}/completions/batch.py>`_.
"""
_operation_impl: type[BatchTaskOperationTypeT]

async def _get(
Expand All @@ -32,6 +38,13 @@ async def _get(
*,
timeout: float = 60,
) -> BatchTaskOperationTypeT:
"""
Get batch task operation by ID or by BatchTaskInfo object.

:param task: Either task ID string or BatchTaskInfo object.
:param timeout: The timeout, or the maximum time to wait for the request to complete in seconds.
Defaults to 60 seconds.
"""
logger.debug('Fetching batch task %s from server', task)

if isinstance(task, BatchTaskInfo):
Expand Down Expand Up @@ -59,6 +72,14 @@ async def _list_operations(
status: UndefinedOr[ProtoEnumCoercible[BatchTaskStatus]] = UNDEFINED,
timeout: float = 60,
) -> AsyncIterator[BatchTaskOperationTypeT]:
"""
List batch task operations with optional filtering.

:param page_size: Maximum number of tasks per page (optional).
:param status: Filter tasks by status (optional).
:param timeout: The timeout, or the maximum time to wait for the request to complete in seconds.
Defaults to 60 seconds.
"""
logger.debug('Fetching batch task list')

async for task_proto in self._list_impl(
Expand All @@ -78,6 +99,14 @@ async def _list_info(
status: UndefinedOr[ProtoEnumCoercible[BatchTaskStatus]] = UNDEFINED,
timeout: float = 60,
) -> AsyncIterator[BatchTaskInfo]:
"""
List batch task information with optional filtering.

:param page_size: Maximum number of tasks per page (optional).
:param status: Filter tasks by status (optional).
:param timeout: The timeout, or the maximum time to wait for the request to complete in seconds.
Defaults to 60 seconds.
"""
logger.debug('Fetching batch task list')

async for task_proto in self._list_impl(
Expand Down Expand Up @@ -138,9 +167,11 @@ async def _list_impl(
page_token = response.next_page_token


@doc_from(BaseBatch, link="async")
class AsyncBatch(BaseBatch[AsyncBatchTaskOperation]):
_operation_impl = AsyncBatchTaskOperation

@doc_from(BaseBatch._get)
async def get(
self,
task: str | BatchTaskInfo,
Expand All @@ -149,6 +180,7 @@ async def get(
) -> AsyncBatchTaskOperation:
return await self._get(task=task, timeout=timeout)

@doc_from(BaseBatch._list_operations)
async def list_operations(
self,
*,
Expand All @@ -163,6 +195,7 @@ async def list_operations(
):
yield task

@doc_from(BaseBatch._list_info)
async def list_info(
self,
*,
Expand All @@ -177,13 +210,14 @@ async def list_info(
):
yield task


@doc_from(BaseBatch, link="sync")
class Batch(BaseBatch[BatchTaskOperation]):
_operation_impl = BatchTaskOperation
__get = run_sync(BaseBatch._get)
__list_operations = run_sync_generator(BaseBatch._list_operations)
__list_info = run_sync_generator(BaseBatch._list_info)

@doc_from(BaseBatch._get)
def get(
self,
task: str | BatchTaskInfo,
Expand All @@ -192,6 +226,7 @@ def get(
) -> BatchTaskOperation:
return self.__get(task=task, timeout=timeout)

@doc_from(BaseBatch._list_operations)
def list_operations(
self,
*,
Expand All @@ -205,6 +240,7 @@ def list_operations(
timeout=timeout
)

@doc_from(BaseBatch._list_info)
def list_info(
self,
*,
Expand Down