Skip to content

Commit b3aed5b

Browse files
authored
fix dataset reading order; limit the number of simultaneously open file descriptors during download (#114)
1 parent 6766d96 commit b3aed5b

File tree

3 files changed

+154
-11
lines changed

3 files changed

+154
-11
lines changed

src/yandex_cloud_ml_sdk/_datasets/dataset.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33

44
import asyncio
55
import dataclasses
6-
import os
76
import tempfile
87
from collections.abc import AsyncIterator, Iterator
98
from datetime import datetime
109
from pathlib import Path
11-
from typing import TYPE_CHECKING, Any, Iterable, TypeVar
10+
from typing import TYPE_CHECKING, Any, Final, Iterable, TypeVar
1211

1312
import aiofiles
1413
import httpx
@@ -38,7 +37,7 @@
3837
logger = get_logger(__name__)
3938

4039
DEFAULT_CHUNK_SIZE = 100 * 1024 ** 2
41-
40+
DEFAULT_MAX_PARALLEL_DOWNLOADS: Final[int] = 16 # maximum number of files open for writing during download
4241

4342
@dataclasses.dataclass(frozen=True)
4443
class ValidationErrorInfo:
@@ -153,6 +152,7 @@ async def _download(
153152
download_path: PathLike,
154153
timeout: float = 60,
155154
exist_ok: bool = False,
155+
max_parallel_downloads: int = DEFAULT_MAX_PARALLEL_DOWNLOADS
156156
) -> tuple[Path, ...]:
157157
logger.debug("Downloading dataset %s", self.id)
158158

@@ -167,6 +167,7 @@ async def _download(
167167
base_path=base_path,
168168
exist_ok=exist_ok,
169169
timeout=timeout,
170+
max_parallel_downloads=max_parallel_downloads,
170171
), timeout)
171172

172173
async def _read(
@@ -176,10 +177,24 @@ async def _read(
176177
batch_size: UndefinedOr[int],
177178
) -> AsyncIterator[dict[Any, Any]]:
178179
batch_size_ = get_defined_value(batch_size, None)
180+
179181
urls = await self._get_download_urls(timeout=timeout)
182+
183+
def key_comparator(item: tuple[str, str]) -> tuple[int, int | str]:
184+
key_, _ = item
185+
key_ = Path(key_).stem
186+
if key_.isdigit():
187+
return 0, int(key_)
188+
else:
189+
return 1, key_ # Non-numeric keys come after numeric keys
190+
191+
sorted_urls = sorted(urls, key=key_comparator)
192+
180193
async with self._client.httpx() as client:
181-
for _, url in urls:
182-
fd, filename = tempfile.mkstemp()
194+
for key, url in sorted_urls:
195+
with tempfile.NamedTemporaryFile(delete=False) as tmp:
196+
filename = tmp.name
197+
183198
path = Path(filename)
184199
try:
185200
await self.__download_file(
@@ -188,29 +203,37 @@ async def _read(
188203
client=client,
189204
timeout=timeout
190205
)
191-
192206
async for record in read_dataset_records(filename, batch_size=batch_size_):
193207
yield record
194208
finally:
195-
os.close(fd)
196-
path.unlink()
209+
if path.exists():
210+
path.unlink()
197211

198212
async def __download_impl(
199213
self,
200214
base_path: Path,
201215
exist_ok: bool,
202216
timeout: float,
217+
max_parallel_downloads: int = DEFAULT_MAX_PARALLEL_DOWNLOADS
203218
) -> tuple[Path, ...]:
204219
urls = await self._get_download_urls(timeout=timeout)
220+
205221
async with self._client.httpx() as client:
222+
223+
semaphore = asyncio.Semaphore(max_parallel_downloads)
224+
225+
async def limited_download(file_path, url) -> None:
    # Gate each transfer on the shared semaphore so that at most
    # max_parallel_downloads files are open for writing at any moment.
    async with semaphore:
        await self.__download_file(file_path, url, client, timeout=timeout)
228+
206229
coroutines = []
207230
for key, url in urls:
208231
file_path = base_path / key
209232
if file_path.exists() and not exist_ok:
210233
raise ValueError(f"{file_path} already exists")
211234

212235
coroutines.append(
213-
self.__download_file(file_path, url, client, timeout=timeout),
236+
limited_download(file_path, url)
214237
)
215238

216239
await asyncio.gather(*coroutines)
@@ -395,11 +418,13 @@ async def download(
395418
download_path: PathLike,
396419
timeout: float = 60,
397420
exist_ok: bool = False,
421+
max_parallel_downloads: int = DEFAULT_MAX_PARALLEL_DOWNLOADS,
398422
) -> tuple[Path, ...]:
399423
return await self._download(
400424
download_path=download_path,
401425
timeout=timeout,
402426
exist_ok=exist_ok,
427+
max_parallel_downloads=max_parallel_downloads,
403428
)
404429

405430
@requires_package('pyarrow', '>=19', 'AsyncDataset.read')

tests/datasets/test_download_datasets.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
# pylint: disable=redefined-outer-name
33
from __future__ import annotations
44

5+
import contextlib
56
from pathlib import Path
67

8+
import aiofiles
79
import httpx
810
import pytest
911
from pytest_httpx import HTTPXMock
@@ -44,7 +46,7 @@ async def test_download_to_temp_dir(mock_dataset, httpx_mock: HTTPXMock, mocker,
4446

4547
paths = await mock_dataset.download(timeout=30, download_path=tmp_path)
4648

47-
assert paths == (tmp_path / "file1.txt", )
49+
assert paths == (tmp_path / "file1.txt",)
4850
assert paths[0].read_bytes() == b"test file content"
4951

5052

@@ -178,5 +180,56 @@ async def test_download_with_exist_ok(mock_dataset, httpx_mock: HTTPXMock, mocke
178180

179181
paths = await mock_dataset.download(timeout=30, download_path=tmp_path, exist_ok=False)
180182

181-
assert paths == (tmp_path / "file1.txt", )
183+
assert paths == (tmp_path / "file1.txt",)
182184
assert paths[0].read_bytes() == b"test file content"
185+
186+
187+
@pytest.mark.asyncio
async def test_download_fd_num(mock_dataset, httpx_mock: HTTPXMock, mocker, tmp_path: Path):
    """Ensure the downloader never holds more than ``max_parallel_downloads`` files open for writing."""
    max_fd_num = 5
    fake_file_num = 10
    peak_open = 0
    currently_open = 0
    real_aiofiles_open = aiofiles.open

    @contextlib.asynccontextmanager
    async def counting_open(*args, **kwargs):
        # Wrap aiofiles.open to record the high-water mark of concurrently open files.
        nonlocal currently_open, peak_open
        currently_open += 1
        peak_open = max(peak_open, currently_open)
        handle = await real_aiofiles_open(*args, **kwargs)
        try:
            yield handle
        finally:
            currently_open -= 1

    mocker.patch("aiofiles.open", counting_open)

    non_empty_dir = tmp_path / "non_empty"
    non_empty_dir.mkdir()
    (non_empty_dir / "file1.txt").write_text("existing content")

    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[(f"file{i}.txt", f"https://example.com/file{i}.txt") for i in range(fake_file_num)],
    )
    for i in range(fake_file_num):
        httpx_mock.add_response(
            url=f"https://example.com/file{i}.txt",
            content=f"test file{i} content".encode(),
        )

    paths = await mock_dataset.download(
        timeout=30,
        download_path=tmp_path,
        exist_ok=False,
        max_parallel_downloads=max_fd_num,
    )

    assert paths == tuple(tmp_path / f"file{i}.txt" for i in range(fake_file_num))
    assert paths[0].read_bytes() == b"test file0 content"

    assert peak_open <= max_fd_num

tests/datasets/test_read.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,42 @@
1+
# pylint: disable=no-name-in-module
2+
# pylint: disable=redefined-outer-name
13
from __future__ import annotations
24

5+
import io
36
import uuid
47
from pathlib import Path
58

9+
import httpx
610
import psutil
11+
import pyarrow as pa
12+
import pyarrow.parquet as pq
713
import pytest
14+
from pytest_httpx import HTTPXMock
15+
from yandex.cloud.ai.dataset.v1.dataset_pb2 import DatasetInfo
816

917
from yandex_cloud_ml_sdk import AsyncYCloudML
18+
from yandex_cloud_ml_sdk._datasets.dataset import AsyncDataset
1019

1120
pytestmark = [pytest.mark.asyncio, pytest.mark.require_env('pyarrow')]
1221

22+
23+
@pytest.fixture
def mock_dataset(mocker) -> AsyncDataset:
    """Build an AsyncDataset whose SDK is mocked but whose HTTP client is real."""
    sdk = mocker.MagicMock()
    # Dataset download/read paths go through sdk._client.httpx(); give them a live client.
    sdk._client.httpx.return_value = httpx.AsyncClient()
    return AsyncDataset._from_proto(
        sdk=sdk,
        proto=DatasetInfo(dataset_id="id"),
    )
37+
38+
39+
1340
@pytest.mark.allow_grpc
1441
@pytest.mark.vcr
1542
async def test_simple_read(async_sdk: AsyncYCloudML, completions_jsonlines: Path) -> None:
@@ -35,3 +62,41 @@ async def test_simple_read(async_sdk: AsyncYCloudML, completions_jsonlines: Path
3562
assert 'response' in line
3663

3764
await dataset.delete()
65+
66+
def make_parquet_bytes(name: str) -> bytes:
    """Serialize a one-row table with a single "name" column to parquet bytes in memory."""
    buffer = io.BytesIO()
    pq.write_table(pa.table({"name": [name]}), buffer)
    return buffer.getvalue()
71+
72+
@pytest.mark.asyncio
async def test_reading_order(mock_dataset, httpx_mock: HTTPXMock, mocker, tmp_path: Path) -> None:
    """Records must be yielded in numeric-key order (non-numeric keys last) with no fd leaks."""
    non_empty_dir = tmp_path / "non_empty"
    non_empty_dir.mkdir()

    # Deliberately shuffled on purpose: read() is expected to sort them.
    file_names = ["1.parquet", "3.parquet", "test.parquet", "4.parquet", "2.parquet"]

    mocker.patch.object(
        mock_dataset,
        "_get_download_urls",
        return_value=[(non_empty_dir / fname, f"https://example.com/{fname}") for fname in file_names],
    )
    for fname in file_names:
        httpx_mock.add_response(
            url=f"https://example.com/{fname}",
            content=make_parquet_bytes(fname),
        )

    process = psutil.Process()
    fds_before = process.num_fds()

    records = [record async for record in mock_dataset.read()]

    # Reading must not leak any file descriptors.
    assert process.num_fds() == fds_before
    assert records == [
        {'name': name}
        for name in ('1.parquet', '2.parquet', '3.parquet', '4.parquet', 'test.parquet')
    ]

0 commit comments

Comments
 (0)