Skip to content

Commit 52aea48

Browse files
committed
Add dataset downloading
1 parent 362a96e commit 52aea48

File tree

4 files changed

+166
-115
lines changed

4 files changed

+166
-115
lines changed

src/yandex_cloud_ml_sdk/_datasets/dataset.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
11
# pylint: disable=no-name-in-module
22
from __future__ import annotations
33

4+
import asyncio
45
import dataclasses
6+
import os
7+
import shutil
8+
import tempfile
59
from datetime import datetime
10+
from pathlib import Path
611
from typing import TYPE_CHECKING, Any, Iterable, TypeVar
712

13+
import aiofiles
14+
import httpx
815
from typing_extensions import Self
916
from yandex.cloud.ai.dataset.v1.dataset_pb2 import DatasetInfo as ProtoDatasetInfo
1017
from yandex.cloud.ai.dataset.v1.dataset_pb2 import ValidationError as ProtoValidationError
1118
from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import (
1219
DeleteDatasetRequest, DeleteDatasetResponse, FinishMultipartUploadDraftRequest, FinishMultipartUploadDraftResponse,
1320
GetUploadDraftUrlRequest, GetUploadDraftUrlResponse, StartMultipartUploadDraftRequest,
14-
StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse, UploadedPartInfo
21+
StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse, UploadedPartInfo,
22+
GetDownloadUrlsRequest, GetDownloadUrlsResponse
1523
)
1624
from yandex.cloud.ai.dataset.v1.dataset_service_pb2_grpc import DatasetServiceStub
1725

1826
from yandex_cloud_ml_sdk._logging import get_logger
19-
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
27+
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value, PathLike, is_defined, coerce_path
2028
from yandex_cloud_ml_sdk._types.resource import BaseDeleteableResource, safe_on_delete
2129
from yandex_cloud_ml_sdk._utils.sync import run_sync
2230

@@ -136,6 +144,66 @@ async def _delete(
136144

137145
logger.info("Dataset %s successfully deleted", self.id)
138146

147+
@safe_on_delete
148+
async def _download(
149+
self,
150+
*,
151+
download_path: UndefinedOr[PathLike] = UNDEFINED,
152+
timeout: float = 60,
153+
) -> list[Path]:
154+
logger.debug("Downloading dataset %s", self.id)
155+
156+
return await asyncio.wait_for(self.__download_impl(
157+
download_path=download_path
158+
), timeout)
159+
160+
async def __download_impl(
161+
self,
162+
*,
163+
download_path: UndefinedOr[PathLike] = UNDEFINED
164+
) -> list[Path]:
165+
if not is_defined(download_path):
166+
# Now using tmp dir. Maybe must be changed to global sdk param
167+
base_path = Path(tempfile.gettempdir()) / "ycml" / "datasets" / self.id
168+
if base_path.exists():
169+
# If using temp dir, and it is not empty, removing it
170+
logger.warning("Dataset %s already downloaded, removing dir %s", self.id, base_path)
171+
shutil.rmtree(base_path)
172+
173+
os.makedirs(base_path, exist_ok=True)
174+
else:
175+
base_path = coerce_path(download_path)
176+
if not base_path.exists():
177+
raise ValueError(f"{base_path} does not exist")
178+
179+
if not base_path.is_dir():
180+
raise ValueError(f"{base_path} is not a directory")
181+
182+
if os.listdir(base_path):
183+
raise ValueError(f"{base_path} is not empty")
184+
185+
urls = await self._get_download_urls()
186+
async with self._client.httpx() as client:
187+
coroutines = [
188+
self.__download_file(base_path / key, url, client) for key, url in urls
189+
]
190+
191+
await asyncio.gather(*coroutines)
192+
193+
return [base_path / key for key, _ in urls]
194+
195+
async def __download_file(self, path: Path, url: str, client: httpx.AsyncClient) -> None:
196+
# For now, assuming that dataset is relatively small and fits RAM
197+
# In the future, downloading by parts must be added
198+
199+
async with aiofiles.open(path, "wb") as file:
200+
resp = await client.get(url)
201+
resp.raise_for_status()
202+
await file.write(resp.read())
203+
await file.flush()
204+
205+
206+
139207
async def _list_upload_formats(
140208
self,
141209
*,
@@ -169,6 +237,29 @@ async def _get_upload_url(
169237
logger.info("Dataset %s upload url successfully fetched", self.id)
170238
return result.upload_url
171239

240+
async def _get_download_urls(
241+
self,
242+
*,
243+
timeout: float = 60,
244+
) -> Iterable[tuple[str, str]]:
245+
logger.debug("Fetching download urls for dataset %s", self.id)
246+
247+
request = GetDownloadUrlsRequest(
248+
dataset_id=self.id,
249+
)
250+
251+
async with self._client.get_service_stub(DatasetServiceStub, timeout=timeout) as stub:
252+
result = await self._client.call_service(
253+
stub.GetDownloadUrls,
254+
request,
255+
timeout=timeout,
256+
expected_type=GetDownloadUrlsResponse,
257+
)
258+
259+
return [
260+
(r.key, r.url) for r in result.download_urls
261+
]
262+
172263
async def _start_multipart_upload(
173264
self,
174265
*,
@@ -256,11 +347,23 @@ async def list_upload_formats(
256347
) -> tuple[str, ...]:
257348
return await self._list_upload_formats(timeout=timeout)
258349

350+
async def download(
351+
self,
352+
*,
353+
download_path: UndefinedOr[PathLike] = UNDEFINED,
354+
timeout: float = 60,
355+
) -> list[Path]:
356+
return await self._download(
357+
download_path=download_path,
358+
timeout=timeout,
359+
)
360+
259361

260362
class Dataset(BaseDataset):
261363
__update = run_sync(BaseDataset._update)
262364
__delete = run_sync(BaseDataset._delete)
263365
__list_upload_formats = run_sync(BaseDataset._list_upload_formats)
366+
__download = run_sync(BaseDataset._download)
264367

265368
def update(
266369
self,
@@ -291,5 +394,16 @@ def list_upload_formats(
291394
) -> tuple[str, ...]:
292395
return self.__list_upload_formats(timeout=timeout)
293396

397+
def download(
398+
self,
399+
*,
400+
download_path: UndefinedOr[PathLike] = UNDEFINED,
401+
timeout: float = 60,
402+
) -> list[Path]:
403+
return self.__download(
404+
download_path=download_path,
405+
timeout=timeout,
406+
)
407+
294408

295409
DatasetTypeT = TypeVar('DatasetTypeT', bound=BaseDataset)

test_requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ pytest-flakes
1010
pytest-mypy
1111
pytest-pylint
1212
pytest-recording
13+
pytest-httpx
14+
pytest-mock
1315
tqdm
1416
types-aiofiles
1517
types-protobuf

tests/datasets/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)