Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/async/datasets/read.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

from __future__ import annotations

import asyncio
import pathlib

from yandex_cloud_ml_sdk import AsyncYCloudML

# Build a recognizable dataset name from this example's location,
# e.g. "example-datasets-read.py", so created datasets are easy to
# find and clean up later via name_pattern.
PATH = pathlib.Path(__file__)
NAME = f'example-{PATH.parent.name}-{PATH.name}'


def local_path(path: str) -> pathlib.Path:
    """Resolve *path* relative to the directory containing this script."""
    script_dir = pathlib.Path(__file__).parent
    return script_dir / path


async def main() -> None:
    # pyarrow must be installed for dataset reading; import early to fail fast
    import pyarrow  # pylint: disable=import-outside-toplevel,unused-import

    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    # On how to upload and work with dataset drafts refer to upload.py example file
    draft = sdk.datasets.draft_from_path(
        task_type='TextToTextGeneration',
        path=local_path('completions.jsonlines'),
        upload_format='jsonlines',
        name=NAME,
    )

    dataset = await draft.upload()
    print(f'Going to read {dataset=} records')
    async for record in dataset.read():
        print(record)

    # Remove every dataset this example created (matched by name)
    async for stale_dataset in sdk.datasets.list(name_pattern=NAME):
        await stale_dataset.delete()


if __name__ == '__main__':
    asyncio.run(main())
41 changes: 41 additions & 0 deletions examples/sync/datasets/read.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

from __future__ import annotations

import pathlib

from yandex_cloud_ml_sdk import YCloudML

# Build a recognizable dataset name from this example's location,
# e.g. "example-datasets-read.py", so created datasets are easy to
# find and clean up later via name_pattern.
PATH = pathlib.Path(__file__)
NAME = f'example-{PATH.parent.name}-{PATH.name}'


def local_path(path: str) -> pathlib.Path:
    """Resolve *path* relative to the directory containing this script."""
    script_dir = pathlib.Path(__file__).parent
    return script_dir / path


def main() -> None:
    # pyarrow must be installed for dataset reading; import early to fail fast
    import pyarrow  # pylint: disable=import-outside-toplevel,unused-import

    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    # On how to upload and work with dataset drafts refer to upload.py example file
    draft = sdk.datasets.draft_from_path(
        task_type='TextToTextGeneration',
        path=local_path('completions.jsonlines'),
        upload_format='jsonlines',
        name=NAME,
    )

    dataset = draft.upload()
    print(f'Going to read {dataset=} records')
    for record in dataset.read():
        print(record)

    # Remove every dataset this example created (matched by name)
    for stale_dataset in sdk.datasets.list(name_pattern=NAME):
        stale_dataset.delete()


if __name__ == '__main__':
    main()
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies = [
"httpx>=0.27,<1",
"typing-extensions>=4",
"aiofiles>=24.1.0",
"packaging>=24"
]

[project.optional-dependencies]
Expand All @@ -53,6 +54,9 @@ langchain = [
pydantic = [
"pydantic>2",
]
datasets = [
"pyarrow>=19"
]

[project.urls]
Documentation = "https://yandex.cloud/ru/docs/foundation-models/"
Expand Down
136 changes: 125 additions & 11 deletions src/yandex_cloud_ml_sdk/_datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from __future__ import annotations

import asyncio
import contextlib
import dataclasses
import tempfile
from collections.abc import AsyncIterator, Iterator
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, TypeVar, AsyncIterator
from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast

import aiofiles
import httpx
Expand All @@ -14,16 +17,18 @@
from yandex.cloud.ai.dataset.v1.dataset_pb2 import ValidationError as ProtoValidationError
from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import (
DeleteDatasetRequest, DeleteDatasetResponse, FinishMultipartUploadDraftRequest, FinishMultipartUploadDraftResponse,
GetUploadDraftUrlRequest, GetUploadDraftUrlResponse, StartMultipartUploadDraftRequest,
StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse, UploadedPartInfo,
GetDownloadUrlsRequest, GetDownloadUrlsResponse
GetDownloadUrlsRequest, GetDownloadUrlsResponse, GetUploadDraftUrlRequest, GetUploadDraftUrlResponse,
StartMultipartUploadDraftRequest, StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse,
UploadedPartInfo
)
from yandex.cloud.ai.dataset.v1.dataset_service_pb2_grpc import DatasetServiceStub

from yandex_cloud_ml_sdk._logging import get_logger
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value, PathLike, coerce_path
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, PathLike, UndefinedOr, coerce_path, get_defined_value
from yandex_cloud_ml_sdk._types.resource import BaseDeleteableResource, safe_on_delete
from yandex_cloud_ml_sdk._utils.sync import run_sync
from yandex_cloud_ml_sdk._utils.packages import requires_package
from yandex_cloud_ml_sdk._utils.pyarrow import read_dataset_records
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator

from .status import DatasetStatus

Expand Down Expand Up @@ -161,14 +166,41 @@ async def _download(
return await asyncio.wait_for(self.__download_impl(
base_path=base_path,
exist_ok=exist_ok,
timeout=timeout,
), timeout)

async def _read(
    self,
    *,
    timeout: float,
    batch_size: UndefinedOr[int],
) -> AsyncIterator[dict[Any, Any]]:
    """Download the dataset's files one by one and lazily yield their records.

    Each backend file is fetched into a temporary file which is removed as
    soon as its records have been consumed (or on error), so disk usage is
    bounded by a single file at a time.

    :param timeout: per-operation timeout for both the gRPC url request
        and every HTTP download.
    :param batch_size: number of records per in-memory chunk; when
        UNDEFINED the reader's own default chunking is used.
    """
    batch_size_ = get_defined_value(batch_size, None)
    urls = await self._get_download_urls(timeout=timeout)
    async with self._client.httpx() as client:
        for _, url in urls:
            # We only need a unique path on disk, so close the handle at
            # once.  NB: the previous tempfile.mkstemp() call discarded the
            # open OS-level fd it returned, leaking one fd per file.
            tmp = tempfile.NamedTemporaryFile(delete=False)  # pylint: disable=consider-using-with
            tmp.close()
            path = Path(tmp.name)
            try:
                await self.__download_file(
                    path=path,
                    url=url,
                    client=client,
                    timeout=timeout
                )

                async for record in read_dataset_records(str(path), batch_size=batch_size_):
                    yield record
            finally:
                # Always clean the temporary file, even if the consumer
                # abandons the generator mid-iteration.
                path.unlink(missing_ok=True)

async def __download_impl(
self,
base_path: Path,
exist_ok: bool,
timeout: float,
) -> tuple[Path, ...]:
urls = await self._get_download_urls()
urls = await self._get_download_urls(timeout=timeout)
async with self._client.httpx() as client:
coroutines = []
for key, url in urls:
Expand All @@ -177,7 +209,7 @@ async def __download_impl(
raise ValueError(f"{file_path} already exists")

coroutines.append(
self.__download_file(file_path, url, client),
self.__download_file(file_path, url, client, timeout=timeout),
)

await asyncio.gather(*coroutines)
Expand All @@ -186,21 +218,27 @@ async def __download_impl(

async def __download_file(
self,
path: Path,
path: Path | str,
url: str,
client: httpx.AsyncClient,
timeout: float,
) -> None:
async with aiofiles.open(path, "wb") as file:
async for chunk in self.__read_from_url(url, client):
logger.debug(
'Going to download file for dataset %s from url %s to %s',
self.id, url, file.name
)
async for chunk in self.__read_from_url(url, client, timeout=timeout):
await file.write(chunk)

async def __read_from_url(
self,
url: str,
client: httpx.AsyncClient,
timeout: float,
chunk_size: int = 1024 * 1024 * 8, # 8Mb
) -> AsyncIterator[bytes]:
resp = await client.get(url)
resp = await client.get(url, timeout=timeout)
resp.raise_for_status()
async for chunk in resp.aiter_bytes(chunk_size=chunk_size):
yield chunk
Expand Down Expand Up @@ -257,6 +295,8 @@ async def _get_download_urls(
expected_type=GetDownloadUrlsResponse,
)

logger.debug("Dataset %s returned next download urls: %r", self.id, result.download_urls)

return [
(r.key, r.url) for r in result.download_urls
]
Expand Down Expand Up @@ -361,12 +401,50 @@ async def download(
exist_ok=exist_ok,
)

@requires_package('pyarrow', '>=19', 'AsyncDataset.read')
async def read(
    self,
    *,
    timeout: float = 60,
    batch_size: UndefinedOr[int] = UNDEFINED,
) -> AsyncIterator[dict[Any, Any]]:
    """Read the dataset from the backend, yielding its records one at a time.

    Records are fetched lazily, chunk by chunk, which keeps memory usage low
    for large datasets.  Every yielded item is a dictionary mapping field
    names to parsed values.

    .. note::
        Temporary files are created in the system's default temporary
        directory while reading and are cleaned up automatically afterwards;
        see Python's :func:`tempfile.gettempdir` documentation to control
        their location.

    :param timeout: upper bound in seconds for each gRPC and HTTP operation
        involved (connection establishment, data transfer and processing).
        Defaults to 60 seconds.
    :type timeout: float
    :param batch_size: number of records to hold in memory per chunk.
        When UNDEFINED (default), the backend's optimal chunk size is used
        (typically corresponding to the distinct Parquet files layout).
    :type batch_size: int or Undefined
    :yields: dictionary with the field/value pairs of a single record
    :rtype: AsyncIterator[dict[Any, Any]]

    """

    records = self._read(
        timeout=timeout,
        batch_size=batch_size
    )
    async for item in records:
        yield item


class Dataset(BaseDataset):
__update = run_sync(BaseDataset._update)
__delete = run_sync(BaseDataset._delete)
__list_upload_formats = run_sync(BaseDataset._list_upload_formats)
__download = run_sync(BaseDataset._download)
__read = run_sync_generator(BaseDataset._read)

def update(
self,
Expand Down Expand Up @@ -410,5 +488,41 @@ def download(
exist_ok=exist_ok,
)

@requires_package('pyarrow', '>=19', 'Dataset.read')
def read(
    self,
    *,
    timeout: float = 60,
    batch_size: UndefinedOr[int] = UNDEFINED,
) -> Iterator[dict[Any, Any]]:
    """Reads the dataset from backend and yields its records one by one.

    This method lazily loads records by chunks, minimizing memory usage for large datasets.
    The iterator yields dictionaries where keys are field names and values are parsed data.

    .. note::
        This method creates temporary files in the system's default temporary directory
        during operation. To control the location of temporary files, refer to Python's
        :func:`tempfile.gettempdir` documentation. Temporary files are automatically
        cleaned up after use.

    :param timeout: Maximum time in seconds for both gRPC and HTTP operations.
        Includes connection establishment, data transfer, and processing time.
        Defaults to 60 seconds.
    :type timeout: float
    :param batch_size: Number of records to load to memory in one chunk.
        When UNDEFINED (default), uses backend's optimal chunk size (typically
        corresponds to distinct Parquet files storage layout).
    :type batch_size: int or Undefined
    :yields: Dictionary representing single record with field-value pairs
    :rtype: Iterator[dict[Any, Any]]

    """

    yield from self.__read(
        timeout=timeout,
        batch_size=batch_size
    )


DatasetTypeT = TypeVar('DatasetTypeT', bound=BaseDataset)
36 changes: 36 additions & 0 deletions src/yandex_cloud_ml_sdk/_utils/pyarrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Iterator
from typing import Any

RecordType = dict[Any, Any]


async def read_dataset_records(path: str, batch_size: int | None) -> AsyncIterator[RecordType]:
    """Asynchronously yield records from a Parquet dataset at *path*.

    The blocking reads are performed in a worker thread via
    :func:`asyncio.to_thread` so the event loop is never blocked.
    """
    sync_iterator = read_dataset_records_sync(path=path, batch_size=batch_size)

    def advance() -> RecordType | None:
        # None marks exhaustion; records themselves are dicts and never None.
        return next(sync_iterator, None)

    while (record := await asyncio.to_thread(advance)) is not None:
        yield record


def read_dataset_records_sync(path: str, batch_size: int | None) -> Iterator[RecordType]:
    """Yield records of the Parquet dataset at *path*, one dict per record."""
    import pyarrow.dataset  # pylint: disable=import-outside-toplevel

    # Only pass batch_size when the caller set it, so pyarrow's own
    # default batch size is preserved otherwise.
    batch_kwargs: dict[str, int] = {}
    if batch_size is not None:
        batch_kwargs['batch_size'] = batch_size

    parquet_dataset = pyarrow.dataset.dataset(source=path, format='parquet')
    for record_batch in parquet_dataset.to_batches(**batch_kwargs):  # type: ignore[arg-type]
        yield from record_batch.to_pylist()
Loading