|
1 | 1 | # pylint: disable=no-name-in-module |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
| 4 | +import asyncio |
4 | 5 | import dataclasses |
| 6 | +import os |
| 7 | +import shutil |
| 8 | +import tempfile |
5 | 9 | from datetime import datetime |
| 10 | +from pathlib import Path |
6 | 11 | from typing import TYPE_CHECKING, Any, Iterable, TypeVar |
7 | 12 |
|
| 13 | +import aiofiles |
| 14 | +import httpx |
8 | 15 | from typing_extensions import Self |
9 | 16 | from yandex.cloud.ai.dataset.v1.dataset_pb2 import DatasetInfo as ProtoDatasetInfo |
10 | 17 | from yandex.cloud.ai.dataset.v1.dataset_pb2 import ValidationError as ProtoValidationError |
11 | 18 | from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import ( |
12 | 19 | DeleteDatasetRequest, DeleteDatasetResponse, FinishMultipartUploadDraftRequest, FinishMultipartUploadDraftResponse, |
13 | 20 | GetUploadDraftUrlRequest, GetUploadDraftUrlResponse, StartMultipartUploadDraftRequest, |
14 | | - StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse, UploadedPartInfo |
| 21 | + StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse, UploadedPartInfo, |
| 22 | + GetDownloadUrlsRequest, GetDownloadUrlsResponse |
15 | 23 | ) |
16 | 24 | from yandex.cloud.ai.dataset.v1.dataset_service_pb2_grpc import DatasetServiceStub |
17 | 25 |
|
18 | 26 | from yandex_cloud_ml_sdk._logging import get_logger |
19 | | -from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value |
| 27 | +from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value, PathLike, is_defined, coerce_path |
20 | 28 | from yandex_cloud_ml_sdk._types.resource import BaseDeleteableResource, safe_on_delete |
21 | 29 | from yandex_cloud_ml_sdk._utils.sync import run_sync |
22 | 30 |
|
@@ -136,6 +144,66 @@ async def _delete( |
136 | 144 |
|
137 | 145 | logger.info("Dataset %s successfully deleted", self.id) |
138 | 146 |
|
| 147 | + @safe_on_delete |
| 148 | + async def _download( |
| 149 | + self, |
| 150 | + *, |
| 151 | + download_path: UndefinedOr[PathLike] = UNDEFINED, |
| 152 | + timeout: float = 60, |
| 153 | + ) -> list[Path]: |
| 154 | + logger.debug("Downloading dataset %s", self.id) |
| 155 | + |
| 156 | + return await asyncio.wait_for(self.__download_impl( |
| 157 | + download_path=download_path |
| 158 | + ), timeout) |
| 159 | + |
| 160 | + async def __download_impl( |
| 161 | + self, |
| 162 | + *, |
| 163 | + download_path: UndefinedOr[PathLike] = UNDEFINED |
| 164 | + ) -> list[Path]: |
| 165 | + if not is_defined(download_path): |
| 166 | + # Now using tmp dir. Maybe must be changed to global sdk param |
| 167 | + base_path = Path(tempfile.gettempdir()) / "ycml" / "datasets" / self.id |
| 168 | + if base_path.exists(): |
| 169 | + # If using temp dir, and it is not empty, removing it |
| 170 | + logger.warning("Dataset %s already downloaded, removing dir %s", self.id, base_path) |
| 171 | + shutil.rmtree(base_path) |
| 172 | + |
| 173 | + os.makedirs(base_path, exist_ok=True) |
| 174 | + else: |
| 175 | + base_path = coerce_path(download_path) |
| 176 | + if not base_path.exists(): |
| 177 | + raise ValueError(f"{base_path} does not exist") |
| 178 | + |
| 179 | + if not base_path.is_dir(): |
| 180 | + raise ValueError(f"{base_path} is not a directory") |
| 181 | + |
| 182 | + if os.listdir(base_path): |
| 183 | + raise ValueError(f"{base_path} is not empty") |
| 184 | + |
| 185 | + urls = await self._get_download_urls() |
| 186 | + async with self._client.httpx() as client: |
| 187 | + coroutines = [ |
| 188 | + self.__download_file(base_path / key, url, client) for key, url in urls |
| 189 | + ] |
| 190 | + |
| 191 | + await asyncio.gather(*coroutines) |
| 192 | + |
| 193 | + return [base_path / key for key, _ in urls] |
| 194 | + |
| 195 | + async def __download_file(self, path: Path, url: str, client: httpx.AsyncClient) -> None: |
| 196 | + # For now, assuming that dataset is relatively small and fits RAM |
| 197 | + # In the future, downloading by parts must be added |
| 198 | + |
| 199 | + async with aiofiles.open(path, "wb") as file: |
| 200 | + resp = await client.get(url) |
| 201 | + resp.raise_for_status() |
| 202 | + await file.write(resp.read()) |
| 203 | + await file.flush() |
| 204 | + |
| 205 | + |
| 206 | + |
139 | 207 | async def _list_upload_formats( |
140 | 208 | self, |
141 | 209 | *, |
@@ -169,6 +237,29 @@ async def _get_upload_url( |
169 | 237 | logger.info("Dataset %s upload url successfully fetched", self.id) |
170 | 238 | return result.upload_url |
171 | 239 |
|
| 240 | + async def _get_download_urls( |
| 241 | + self, |
| 242 | + *, |
| 243 | + timeout: float = 60, |
| 244 | + ) -> Iterable[tuple[str, str]]: |
| 245 | + logger.debug("Fetching download urls for dataset %s", self.id) |
| 246 | + |
| 247 | + request = GetDownloadUrlsRequest( |
| 248 | + dataset_id=self.id, |
| 249 | + ) |
| 250 | + |
| 251 | + async with self._client.get_service_stub(DatasetServiceStub, timeout=timeout) as stub: |
| 252 | + result = await self._client.call_service( |
| 253 | + stub.GetDownloadUrls, |
| 254 | + request, |
| 255 | + timeout=timeout, |
| 256 | + expected_type=GetDownloadUrlsResponse, |
| 257 | + ) |
| 258 | + |
| 259 | + return [ |
| 260 | + (r.key, r.url) for r in result.download_urls |
| 261 | + ] |
| 262 | + |
172 | 263 | async def _start_multipart_upload( |
173 | 264 | self, |
174 | 265 | *, |
@@ -256,11 +347,23 @@ async def list_upload_formats( |
256 | 347 | ) -> tuple[str, ...]: |
257 | 348 | return await self._list_upload_formats(timeout=timeout) |
258 | 349 |
|
    async def download(
        self,
        *,
        download_path: UndefinedOr[PathLike] = UNDEFINED,
        timeout: float = 60,
    ) -> list[Path]:
        """Download the dataset's files and return their local paths.

        :param download_path: existing empty directory to download into;
            when omitted, a temporary per-dataset directory is used.
        :param timeout: overall time budget for the download, in seconds.
        """
        return await self._download(
            download_path=download_path,
            timeout=timeout,
        )
| 360 | + |
259 | 361 |
|
260 | 362 | class Dataset(BaseDataset): |
261 | 363 | __update = run_sync(BaseDataset._update) |
262 | 364 | __delete = run_sync(BaseDataset._delete) |
263 | 365 | __list_upload_formats = run_sync(BaseDataset._list_upload_formats) |
    # Synchronous facade over the async BaseDataset._download
    __download = run_sync(BaseDataset._download)
264 | 367 |
|
265 | 368 | def update( |
266 | 369 | self, |
@@ -291,5 +394,16 @@ def list_upload_formats( |
291 | 394 | ) -> tuple[str, ...]: |
292 | 395 | return self.__list_upload_formats(timeout=timeout) |
293 | 396 |
|
    def download(
        self,
        *,
        download_path: UndefinedOr[PathLike] = UNDEFINED,
        timeout: float = 60,
    ) -> list[Path]:
        """Download the dataset's files and return their local paths (sync version).

        :param download_path: existing empty directory to download into;
            when omitted, a temporary per-dataset directory is used.
        :param timeout: overall time budget for the download, in seconds.
        """
        return self.__download(
            download_path=download_path,
            timeout=timeout,
        )
| 407 | + |
294 | 408 |
|
295 | 409 | DatasetTypeT = TypeVar('DatasetTypeT', bound=BaseDataset) |
0 commit comments