Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 116 additions & 2 deletions src/yandex_cloud_ml_sdk/_datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
# pylint: disable=no-name-in-module
from __future__ import annotations

import asyncio
import dataclasses
import os
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, TypeVar

import aiofiles
import httpx
from typing_extensions import Self
from yandex.cloud.ai.dataset.v1.dataset_pb2 import DatasetInfo as ProtoDatasetInfo
from yandex.cloud.ai.dataset.v1.dataset_pb2 import ValidationError as ProtoValidationError
from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import (
DeleteDatasetRequest, DeleteDatasetResponse, FinishMultipartUploadDraftRequest, FinishMultipartUploadDraftResponse,
GetUploadDraftUrlRequest, GetUploadDraftUrlResponse, StartMultipartUploadDraftRequest,
StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse, UploadedPartInfo
StartMultipartUploadDraftResponse, UpdateDatasetRequest, UpdateDatasetResponse, UploadedPartInfo,
GetDownloadUrlsRequest, GetDownloadUrlsResponse
)
from yandex.cloud.ai.dataset.v1.dataset_service_pb2_grpc import DatasetServiceStub

from yandex_cloud_ml_sdk._logging import get_logger
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value, PathLike, is_defined, coerce_path
from yandex_cloud_ml_sdk._types.resource import BaseDeleteableResource, safe_on_delete
from yandex_cloud_ml_sdk._utils.sync import run_sync

Expand Down Expand Up @@ -136,6 +144,66 @@ async def _delete(

logger.info("Dataset %s successfully deleted", self.id)

@safe_on_delete
async def _download(
    self,
    *,
    download_path: UndefinedOr[PathLike] = UNDEFINED,
    timeout: float = 60,
) -> list[Path]:
    """Download every file of this dataset, bounded by an overall deadline.

    :param download_path: target directory; when undefined, a per-dataset
        temporary directory is used instead.
    :param timeout: deadline for the whole download, in seconds.
    :return: paths of the downloaded files.
    """
    logger.debug("Downloading dataset %s", self.id)

    # Wrap the actual work in a single deadline covering all files.
    download_coro = self.__download_impl(download_path=download_path)
    return await asyncio.wait_for(download_coro, timeout)

async def __download_impl(
    self,
    *,
    download_path: UndefinedOr[PathLike] = UNDEFINED
) -> list[Path]:
    """Resolve the target directory, fetch download urls and save all files.

    When ``download_path`` is undefined, a per-dataset temp directory is
    (re)created from scratch; otherwise the given directory must already
    exist, be a directory, and be empty.

    :raises ValueError: on an unusable ``download_path`` or on a
        server-supplied file key that would escape the target directory.
    """
    if not is_defined(download_path):
        # Now using tmp dir. Maybe must be changed to global sdk param
        base_path = Path(tempfile.gettempdir()) / "ycml" / "datasets" / self.id
        if base_path.exists():
            # If using temp dir, and it is not empty, removing it
            logger.warning("Dataset %s already downloaded, removing dir %s", self.id, base_path)
            shutil.rmtree(base_path)

        base_path.mkdir(parents=True, exist_ok=True)
    else:
        base_path = coerce_path(download_path)
        if not base_path.exists():
            raise ValueError(f"{base_path} does not exist")

        if not base_path.is_dir():
            raise ValueError(f"{base_path} is not a directory")

        if any(base_path.iterdir()):
            raise ValueError(f"{base_path} is not empty")

    urls = list(await self._get_download_urls())

    base_root = base_path.resolve()
    paths: list[Path] = []
    for key, _ in urls:
        path = base_path / key
        # Keys come from the server; refuse anything like "../x" that
        # would resolve outside the chosen download directory.
        if base_root not in path.resolve().parents:
            raise ValueError(f"dataset file key {key!r} escapes download directory {base_path}")
        # Keys may contain subdirectory components; create parents so
        # aiofiles.open does not fail with FileNotFoundError.
        path.parent.mkdir(parents=True, exist_ok=True)
        paths.append(path)

    async with self._client.httpx() as client:
        await asyncio.gather(*(
            self.__download_file(path, url, client)
            for path, (_, url) in zip(paths, urls)
        ))

    return paths

async def __download_file(self, path: Path, url: str, client: httpx.AsyncClient) -> None:
    """Stream a single file from *url* into *path*.

    Streams the body in chunks instead of buffering it whole, so large
    dataset files do not have to fit in RAM.  The file is opened only
    after the status check, so an HTTP error does not leave an empty
    file behind.
    """
    async with client.stream("GET", url) as resp:
        resp.raise_for_status()
        async with aiofiles.open(path, "wb") as file:
            async for chunk in resp.aiter_bytes():
                await file.write(chunk)



async def _list_upload_formats(
self,
*,
Expand Down Expand Up @@ -169,6 +237,29 @@ async def _get_upload_url(
logger.info("Dataset %s upload url successfully fetched", self.id)
return result.upload_url

async def _get_download_urls(
    self,
    *,
    timeout: float = 60,
) -> Iterable[tuple[str, str]]:
    """Fetch ``(key, url)`` pairs for every downloadable file of this dataset.

    :param timeout: per-call gRPC timeout, in seconds.
    """
    logger.debug("Fetching download urls for dataset %s", self.id)

    request = GetDownloadUrlsRequest(dataset_id=self.id)

    async with self._client.get_service_stub(DatasetServiceStub, timeout=timeout) as stub:
        response = await self._client.call_service(
            stub.GetDownloadUrls,
            request,
            timeout=timeout,
            expected_type=GetDownloadUrlsResponse,
        )

    return [(item.key, item.url) for item in response.download_urls]

async def _start_multipart_upload(
self,
*,
Expand Down Expand Up @@ -256,11 +347,23 @@ async def list_upload_formats(
) -> tuple[str, ...]:
return await self._list_upload_formats(timeout=timeout)

async def download(
    self,
    *,
    download_path: UndefinedOr[PathLike] = UNDEFINED,
    timeout: float = 60,
) -> list[Path]:
    """Download all files of this dataset and return their local paths.

    :param download_path: existing empty directory to download into;
        when undefined, a temporary directory is used.
    :param timeout: overall deadline for the download, in seconds.
    """
    # Public entry point; the actual work lives in the protected helper.
    return await self._download(download_path=download_path, timeout=timeout)


class Dataset(BaseDataset):
__update = run_sync(BaseDataset._update)
__delete = run_sync(BaseDataset._delete)
__list_upload_formats = run_sync(BaseDataset._list_upload_formats)
__download = run_sync(BaseDataset._download)

def update(
self,
Expand Down Expand Up @@ -291,5 +394,16 @@ def list_upload_formats(
) -> tuple[str, ...]:
return self.__list_upload_formats(timeout=timeout)

def download(
    self,
    *,
    download_path: UndefinedOr[PathLike] = UNDEFINED,
    timeout: float = 60,
) -> list[Path]:
    """Download all files of this dataset and return their local paths.

    Synchronous counterpart of :meth:`AsyncDataset.download`.

    :param download_path: existing empty directory to download into;
        when undefined, a temporary directory is used.
    :param timeout: overall deadline for the download, in seconds.
    """
    # Delegates to the run_sync-wrapped async implementation.
    return self.__download(download_path=download_path, timeout=timeout)


DatasetTypeT = TypeVar('DatasetTypeT', bound=BaseDataset)
2 changes: 2 additions & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pytest-flakes
pytest-mypy
pytest-pylint
pytest-recording
pytest-httpx
pytest-mock
tqdm
types-aiofiles
types-protobuf
Expand Down
Empty file added tests/datasets/__init__.py
Empty file.
191 changes: 191 additions & 0 deletions tests/datasets/test_download_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import tempfile
from pathlib import Path

import httpx
import pytest
from pytest_httpx import HTTPXMock
from yandex.cloud.ai.dataset.v1.dataset_pb2 import DatasetInfo

from yandex_cloud_ml_sdk._datasets.dataset import AsyncDataset


@pytest.fixture
def mock_dataset(mocker, tmp_path: Path) -> AsyncDataset:
"""Create a mock dataset for testing."""
sdk_mock = mocker.MagicMock()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, what are the benefits of pytest-mock over unittest.mock.MagicMock?

Copy link
Copy Markdown
Collaborator Author

@ArtoLord ArtoLord Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytest-mock is integrated with pytest and automatically undoes mocks after each test method

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, you use a fixture for this, which also removes everything.
Also, instead of mocker.patch there is the builtin monkeypatch fixture in pytest

sdk_mock._client.httpx.return_value = httpx.AsyncClient()

dataset = AsyncDataset._from_proto(
sdk=sdk_mock,
proto=DatasetInfo(
dataset_id="id"
)
)

mocker.patch(
"tempfile.gettempdir",
return_value=str(tmp_path),
)

return dataset


@pytest.mark.asyncio
async def test_download_to_temp_dir(mock_dataset, httpx_mock: HTTPXMock, mocker):
    """Files land in the per-dataset temp dir when no path is given."""
    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[("file1.txt", "https://example.com/file1.txt")],
    )
    httpx_mock.add_response(
        url="https://example.com/file1.txt",
        content=b"test file content",
    )

    downloaded = await mock_dataset.download(timeout=30)

    expected_dir = Path(tempfile.gettempdir()) / "ycml" / "datasets" / mock_dataset.id
    assert expected_dir.exists()
    assert downloaded == [expected_dir / "file1.txt"]
    assert downloaded[0].read_bytes() == b"test file content"


@pytest.mark.asyncio
async def test_download_to_custom_dir(mock_dataset, tmp_path, httpx_mock: HTTPXMock, mocker):
    """A user-supplied empty directory receives the downloaded files."""
    target_dir = tmp_path / "empty"
    target_dir.mkdir()

    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[("file1.txt", "https://example.com/file1.txt")],
    )
    httpx_mock.add_response(
        url="https://example.com/file1.txt",
        content=b"test file content",
    )

    downloaded = await mock_dataset.download(download_path=target_dir, timeout=30)

    assert downloaded == [target_dir / "file1.txt"]
    assert downloaded[0].read_bytes() == b"test file content"


@pytest.mark.asyncio
async def test_download_multiple_files(httpx_mock: HTTPXMock, mock_dataset, tmp_path, mocker):
    """Several dataset files are all fetched and written to disk."""
    target_dir = tmp_path / "empty"
    target_dir.mkdir()

    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[
            ("file1.txt", "https://example.com/file1.txt"),
            ("file2.txt", "https://example.com/file2.txt"),
        ],
    )
    httpx_mock.add_response(
        url="https://example.com/file1.txt",
        content=b"content of file 1",
    )
    httpx_mock.add_response(
        url="https://example.com/file2.txt",
        content=b"content of file 2",
    )

    downloaded = list(await mock_dataset.download(download_path=target_dir, timeout=30))

    assert len(downloaded) == 2
    assert {path.name for path in downloaded} == {"file1.txt", "file2.txt"}
    assert (target_dir / "file1.txt").read_bytes() == b"content of file 1"
    assert (target_dir / "file2.txt").read_bytes() == b"content of file 2"


@pytest.mark.asyncio
async def test_download_to_non_existent_dir(mock_dataset, tmp_path, mocker):
    """A missing target directory is rejected with ValueError."""
    missing_dir = tmp_path / "does_not_exist"

    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[("file1.txt", "https://example.com/file1.txt")],
    )

    with pytest.raises(ValueError, match="does not exist"):
        await mock_dataset.download(download_path=missing_dir, timeout=30)


@pytest.mark.asyncio
async def test_download_to_file_path(mock_dataset, tmp_path, mocker):
    """Passing a regular file instead of a directory is rejected."""
    not_a_dir = tmp_path / "file.txt"
    not_a_dir.touch()

    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[("file1.txt", "https://example.com/file1.txt")],
    )

    with pytest.raises(ValueError, match="is not a directory"):
        await mock_dataset.download(download_path=not_a_dir, timeout=30)


@pytest.mark.asyncio
async def test_download_to_non_empty_dir(mock_dataset, tmp_path, mocker):
    """A directory that already has content is rejected."""
    occupied_dir = tmp_path / "non_empty"
    occupied_dir.mkdir()
    (occupied_dir / "existing_file.txt").write_text("existing content")

    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[("file1.txt", "https://example.com/file1.txt")],
    )

    with pytest.raises(ValueError, match="is not empty"):
        await mock_dataset.download(download_path=occupied_dir, timeout=30)


@pytest.mark.asyncio
async def test_download_http_error(httpx_mock: HTTPXMock, mock_dataset, tmp_path, mocker):
    """An HTTP error status during download surfaces as HTTPStatusError."""
    target_dir = tmp_path / "empty"
    target_dir.mkdir()

    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[("file1.txt", "https://example.com/file1.txt")],
    )
    httpx_mock.add_response(
        url="https://example.com/file1.txt",
        status_code=404,
    )

    with pytest.raises(httpx.HTTPStatusError):
        await mock_dataset.download(download_path=target_dir, timeout=30)
Loading