Commit cedcdda

Add datasets.read method

1 parent: 19cd365

5 files changed: +242 −11 lines

examples/async/datasets/read.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
#!/usr/bin/env python3

from __future__ import annotations

import asyncio
import pathlib

from yandex_cloud_ml_sdk import AsyncYCloudML

PATH = pathlib.Path(__file__)
NAME = f'example-{PATH.parent.name}-{PATH.name}'


def local_path(path: str) -> pathlib.Path:
    return pathlib.Path(__file__).parent / path


async def main() -> None:
    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    # For how to upload and work with dataset drafts, see the upload.py example file
    dataset_draft = sdk.datasets.draft_from_path(
        task_type='TextToTextGeneration',
        path=local_path('completions.jsonlines'),
        upload_format='jsonlines',
        name=NAME,
    )
    dataset = await dataset_draft.upload()
    print(f'new {dataset=}')
    async for record in dataset.read():
        print(record)

    async for dataset in sdk.datasets.list(name_pattern=NAME):
        await dataset.delete()


if __name__ == '__main__':
    asyncio.run(main())
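
Both this example and the sync variant below expect a completions.jsonlines file next to the script; that data file is not part of this commit. Purely as a hypothetical illustration, such a file could be generated with a few lines of Python — note that the request/response record schema used here is an assumption about the text-to-text tuning format, not something this diff specifies:

#!/usr/bin/env python3
# Hypothetical helper, not part of this commit: writes a tiny
# completions.jsonlines next to this script. The record schema is an
# assumption; consult the dataset documentation for the exact fields
# expected by the TextToTextGeneration task type.
from __future__ import annotations

import json
import pathlib

records = [
    {'request': [{'role': 'user', 'text': '2 + 2 = ?'}], 'response': '4'},
    {'request': [{'role': 'user', 'text': 'Name the capital of France.'}], 'response': 'Paris'},
]

path = pathlib.Path(__file__).parent / 'completions.jsonlines'
with path.open('w', encoding='utf-8') as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')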

examples/sync/datasets/read.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

from __future__ import annotations

import pathlib

from yandex_cloud_ml_sdk import YCloudML

PATH = pathlib.Path(__file__)
NAME = f'example-{PATH.parent.name}-{PATH.name}'


def local_path(path: str) -> pathlib.Path:
    return pathlib.Path(__file__).parent / path


def main() -> None:
    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    # For how to upload and work with dataset drafts, see the upload.py example file
    dataset_draft = sdk.datasets.draft_from_path(
        task_type='TextToTextGeneration',
        path=local_path('completions.jsonlines'),
        upload_format='jsonlines',
        name=NAME,
    )
    dataset = dataset_draft.upload()
    print(f'new {dataset=}')
    for record in dataset.read():
        print(record)

    for dataset in sdk.datasets.list(name_pattern=NAME):
        dataset.delete()


if __name__ == '__main__':
    main()

pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,7 @@ dependencies = [
     "httpx>=0.27,<1",
     "typing-extensions>=4",
     "aiofiles>=24.1.0",
+    "packaging>=24"
 ]

 [project.optional-dependencies]
@@ -53,6 +54,9 @@ langchain = [
 pydantic = [
     "pydantic>2",
 ]
+datasets = [
+    "pyarrow>=19"
+]

 [project.urls]
 Documentation = "https://yandex.cloud/ru/docs/foundation-models/"
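
Note that pyarrow lands in a new optional datasets extra rather than in the core dependencies, so plain installs stay lightweight; users opt in with something like pip install 'yandex-cloud-ml-sdk[datasets]'. The packaging>=24 requirement presumably backs the requires_package guard used in dataset.py below, whose implementation is not included in this diff. A minimal sketch of such a guard, assuming it checks the installed distribution against a version specifier at call time, might look like:

# Hypothetical sketch of a requires_package-style decorator; the real helper
# lives in yandex_cloud_ml_sdk._utils.packages and is not shown in this commit.
from __future__ import annotations

import functools
from collections.abc import Callable
from importlib.metadata import PackageNotFoundError, version
from typing import Any, TypeVar

from packaging.specifiers import SpecifierSet

F = TypeVar('F', bound=Callable[..., Any])


def requires_package(package: str, specifier: str, feature: str) -> Callable[[F], F]:
    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # fail fast with a readable error instead of an ImportError
            # somewhere deep inside the feature's implementation
            try:
                installed = version(package)
            except PackageNotFoundError as exc:
                raise RuntimeError(
                    f'{feature} requires the {package!r} package to be installed'
                ) from exc
            if installed not in SpecifierSet(specifier):
                raise RuntimeError(
                    f'{feature} requires {package}{specifier}, but {installed} is installed'
                )
            return func(*args, **kwargs)
        return wrapper  # type: ignore[return-value]
    return decorator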

src/yandex_cloud_ml_sdk/_datasets/dataset.py

Lines changed: 125 additions & 11 deletions
@@ -2,10 +2,13 @@
 from __future__ import annotations

 import asyncio
+import contextlib
 import dataclasses
+import tempfile
+from collections.abc import AsyncIterator, Iterator
 from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, TypeVar, AsyncIterator
+from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast

 import aiofiles
 import httpx
@@ -14,16 +17,18 @@
 from yandex.cloud.ai.dataset.v1.dataset_pb2 import ValidationError as ProtoValidationError
 from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import (
     DeleteDatasetRequest, DeleteDatasetResponse, FinishMultipartUploadDraftRequest, FinishMultipartUploadDraftResponse,
-    GetUploadDraftUrlRequest, GetUploadDraftUrlResponse, StartMultipartUploadDraftRequest,
-    StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse, UploadedPartInfo,
-    GetDownloadUrlsRequest, GetDownloadUrlsResponse
+    GetDownloadUrlsRequest, GetDownloadUrlsResponse, GetUploadDraftUrlRequest, GetUploadDraftUrlResponse,
+    StartMultipartUploadDraftRequest, StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse,
+    UploadedPartInfo
 )
 from yandex.cloud.ai.dataset.v1.dataset_service_pb2_grpc import DatasetServiceStub

 from yandex_cloud_ml_sdk._logging import get_logger
-from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value, PathLike, coerce_path
+from yandex_cloud_ml_sdk._types.misc import UNDEFINED, PathLike, UndefinedOr, coerce_path, get_defined_value
 from yandex_cloud_ml_sdk._types.resource import BaseDeleteableResource, safe_on_delete
-from yandex_cloud_ml_sdk._utils.sync import run_sync
+from yandex_cloud_ml_sdk._utils.packages import requires_package
+from yandex_cloud_ml_sdk._utils.pyarrow import read_dataset_records
+from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator

 from .status import DatasetStatus

@@ -161,14 +166,41 @@ async def _download(
         return await asyncio.wait_for(self.__download_impl(
             base_path=base_path,
             exist_ok=exist_ok,
+            timeout=timeout,
         ), timeout)

+    async def _read(
+        self,
+        *,
+        timeout: float,
+        batch_size: UndefinedOr[int],
+    ) -> AsyncIterator[dict[Any, Any]]:
+        batch_size_ = get_defined_value(batch_size, None)
+        urls = await self._get_download_urls(timeout=timeout)
+        async with self._client.httpx() as client:
+            for _, url in urls:
+                _, filename = tempfile.mkstemp()
+                path = Path(filename)
+                try:
+                    await self.__download_file(
+                        path=path,
+                        url=url,
+                        client=client,
+                        timeout=timeout
+                    )
+
+                    async for record in read_dataset_records(filename, batch_size=batch_size_):
+                        yield record
+                finally:
+                    path.unlink(missing_ok=True)
+
     async def __download_impl(
         self,
         base_path: Path,
         exist_ok: bool,
+        timeout: float,
     ) -> tuple[Path, ...]:
-        urls = await self._get_download_urls()
+        urls = await self._get_download_urls(timeout=timeout)
         async with self._client.httpx() as client:
             coroutines = []
             for key, url in urls:
@@ -177,7 +209,7 @@ async def __download_impl(
                     raise ValueError(f"{file_path} already exists")

                 coroutines.append(
-                    self.__download_file(file_path, url, client),
+                    self.__download_file(file_path, url, client, timeout=timeout),
                 )

             await asyncio.gather(*coroutines)
@@ -186,21 +218,27 @@ async def __download_impl(
     async def __download_file(
         self,
-        path: Path,
+        path: Path | str,
         url: str,
         client: httpx.AsyncClient,
+        timeout: float,
     ) -> None:
         async with aiofiles.open(path, "wb") as file:
-            async for chunk in self.__read_from_url(url, client):
+            logger.debug(
+                'Going to download file for dataset %s from url %s to %s',
+                self.id, url, file.name
+            )
+            async for chunk in self.__read_from_url(url, client, timeout=timeout):
                 await file.write(chunk)

     async def __read_from_url(
         self,
         url: str,
         client: httpx.AsyncClient,
+        timeout: float,
         chunk_size: int = 1024 * 1024 * 8,  # 8Mb
     ) -> AsyncIterator[bytes]:
-        resp = await client.get(url)
+        resp = await client.get(url, timeout=timeout)
         resp.raise_for_status()
         async for chunk in resp.aiter_bytes(chunk_size=chunk_size):
             yield chunk
@@ -257,6 +295,8 @@ async def _get_download_urls(
             expected_type=GetDownloadUrlsResponse,
         )

+        logger.debug("Dataset %s returned the following download urls: %r", self.id, result.download_urls)
+
         return [
             (r.key, r.url) for r in result.download_urls
         ]
@@ -361,12 +401,50 @@ async def download(
             exist_ok=exist_ok,
         )

+    @requires_package('pyarrow', '>=19', 'AsyncDataset.read')
+    async def read(
+        self,
+        *,
+        timeout: float = 60,
+        batch_size: UndefinedOr[int] = UNDEFINED,
+    ) -> AsyncIterator[dict[Any, Any]]:
+        """Reads the dataset from the backend and yields its records one by one.
+
+        This method lazily loads records in chunks, minimizing memory usage for large datasets.
+        The iterator yields dictionaries where keys are field names and values are parsed data.
+
+        .. note::
+            This method creates temporary files in the system's default temporary directory
+            during operation. To control the location of temporary files, refer to Python's
+            :func:`tempfile.gettempdir` documentation. Temporary files are automatically
+            cleaned up after use.
+
+        :param timeout: Maximum time in seconds for both gRPC and HTTP operations.
+            Includes connection establishment, data transfer, and processing time.
+            Defaults to 60 seconds.
+        :type timeout: float
+        :param batch_size: Number of records to load into memory in one chunk.
+            When UNDEFINED (default), uses the backend's optimal chunk size (which typically
+            corresponds to the distinct Parquet files in the storage layout).
+        :type batch_size: int or Undefined
+        :yields: Dictionary representing a single record with field-value pairs.
+        :rtype: AsyncIterator[dict[Any, Any]]
+        """
+        async for record in self._read(
+            timeout=timeout,
+            batch_size=batch_size
+        ):
+            yield record
+

 class Dataset(BaseDataset):
     __update = run_sync(BaseDataset._update)
     __delete = run_sync(BaseDataset._delete)
     __list_upload_formats = run_sync(BaseDataset._list_upload_formats)
     __download = run_sync(BaseDataset._download)
+    __read = run_sync_generator(BaseDataset._read)

     def update(
         self,
@@ -410,5 +488,41 @@ def download(
             exist_ok=exist_ok,
         )

+    @requires_package('pyarrow', '>=19', 'Dataset.read')
+    def read(
+        self,
+        *,
+        timeout: float = 60,
+        batch_size: UndefinedOr[int] = UNDEFINED,
+    ) -> Iterator[dict[Any, Any]]:
+        """Reads the dataset from the backend and yields its records one by one.
+
+        This method lazily loads records in chunks, minimizing memory usage for large datasets.
+        The iterator yields dictionaries where keys are field names and values are parsed data.
+
+        .. note::
+            This method creates temporary files in the system's default temporary directory
+            during operation. To control the location of temporary files, refer to Python's
+            :func:`tempfile.gettempdir` documentation. Temporary files are automatically
+            cleaned up after use.
+
+        :param timeout: Maximum time in seconds for both gRPC and HTTP operations.
+            Includes connection establishment, data transfer, and processing time.
+            Defaults to 60 seconds.
+        :type timeout: float
+        :param batch_size: Number of records to load into memory in one chunk.
+            When UNDEFINED (default), uses the backend's optimal chunk size (which typically
+            corresponds to the distinct Parquet files in the storage layout).
+        :type batch_size: int or Undefined
+        :yields: Dictionary representing a single record with field-value pairs.
+        :rtype: Iterator[dict[Any, Any]]
+        """
+        yield from self.__read(
+            timeout=timeout,
+            batch_size=batch_size
+        )
+

 DatasetTypeT = TypeVar('DatasetTypeT', bound=BaseDataset)
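
Among the helpers this file imports, run_sync_generator is new next to run_sync, and its implementation is not part of this diff. The technique it names — driving an async generator from synchronous code — can be sketched roughly as follows; this simplified stand-in creates a private event loop, whereas the real helper in _utils.sync presumably reuses the SDK's own loop machinery:

# Simplified, hypothetical sketch of a run_sync_generator-style adapter;
# the actual helper in yandex_cloud_ml_sdk._utils.sync is not shown here.
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Callable, Iterator
from typing import Any


def run_sync_generator(
    func: Callable[..., AsyncIterator[Any]],
) -> Callable[..., Iterator[Any]]:
    def wrapper(*args: Any, **kwargs: Any) -> Iterator[Any]:
        agen = func(*args, **kwargs)
        loop = asyncio.new_event_loop()
        try:
            while True:
                try:
                    # block the calling thread until the async generator
                    # produces its next item
                    yield loop.run_until_complete(agen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            loop.run_until_complete(agen.aclose())
            loop.close()
    return wrapper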
src/yandex_cloud_ml_sdk/_utils/pyarrow.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Iterator
from typing import Any

import pyarrow.dataset as pd

RecordType = dict[Any, Any]


async def read_dataset_records(path: str, batch_size: int | None) -> AsyncIterator[RecordType]:
    iterator = read_dataset_records_sync(path=path, batch_size=batch_size)

    def get_next() -> RecordType | None:
        try:
            return next(iterator)
        except StopIteration:
            return None

    while True:
        item = await asyncio.to_thread(get_next)
        if item is None:
            return

        yield item


def read_dataset_records_sync(path: str, batch_size: int | None) -> Iterator[RecordType]:
    # pass batch_size via kwargs so that pyarrow's own default is preserved when it is None
    kwargs = {}
    if batch_size is not None:
        kwargs['batch_size'] = batch_size
    dataset = pd.dataset(source=path, format='parquet')
    for batch in dataset.to_batches(**kwargs):  # type: ignore[arg-type]
        yield from batch.to_pylist()
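
With the datasets extra installed, these helpers can be sanity-checked locally by round-tripping a small Parquet file; the column names below are made up for the demonstration:

# Round-trip demo for read_dataset_records_sync: write a tiny Parquet file,
# then stream its rows back as plain dicts. Column names are arbitrary.
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

from yandex_cloud_ml_sdk._utils.pyarrow import read_dataset_records_sync

table = pa.table({'question': ['2 + 2 = ?', 'Capital of France?'], 'answer': ['4', 'Paris']})
with tempfile.NamedTemporaryFile(suffix='.parquet') as tmp:
    pq.write_table(table, tmp.name)
    for record in read_dataset_records_sync(tmp.name, batch_size=None):
        print(record)  # e.g. {'question': '2 + 2 = ?', 'answer': '4'}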
