Skip to content

Commit 3aa0ed7

Browse files
committed
feat(libcommon): use async parquet_index.query() methods everywhere
1 parent 67fb0b4 commit 3aa0ed7

File tree

6 files changed

+64
-64
lines changed

6 files changed

+64
-64
lines changed

libs/libcommon/src/libcommon/parquet_utils.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,12 @@
33
import os
44
from collections.abc import Iterable
55
from dataclasses import dataclass, field
6-
from functools import lru_cache
6+
from functools import lru_cache, partial
77
from pathlib import Path
88
from typing import Optional, TypedDict
99
from urllib.parse import unquote
1010

11+
import anyio
1112
import numpy as np
1213
import pyarrow as pa
1314
import pyarrow.compute as pc
@@ -526,7 +527,7 @@ def _init_viewer_index(
526527

527528
# note that this cache size is global for the class, not per instance
528529
@lru_cache(maxsize=1)
529-
def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
530+
async def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
530531
"""Query the parquet files
531532
532533
Note that this implementation will always read at least one row group, to get the list of columns and always
@@ -541,11 +542,11 @@ def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
541542
`list[str]`: List of truncated columns.
542543
"""
543544
if self._use_libviewer:
544-
return self.query_libviewer_index(offset=offset, length=length)
545+
return await self.query_libviewer_index(offset=offset, length=length)
545546
else:
546-
return self.query_parquet_index(offset=offset, length=length)
547+
return await self.query_parquet_index(offset=offset, length=length)
547548

548-
def query_parquet_index(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
549+
async def query_parquet_index(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
549550
"""Query the parquet files using ParquetIndexWithMetadata.
550551
551552
This is the old implementation without libviewer doing row-group pruning using pyarrow.
@@ -554,9 +555,11 @@ def query_parquet_index(self, offset: int, length: int) -> tuple[pa.Table, list[
554555
f"Query {type(self.parquet_index).__name__} for dataset={self.dataset}, config={self.config},"
555556
f" split={self.split}, offset={offset}, length={length}"
556557
)
557-
return self.parquet_index.query(offset=offset, length=length)
558+
# run_sync doesn't support keyword arguments, so use partial
559+
queryfn = partial(self.parquet_index.query, offset=offset, length=length)
560+
return await anyio.to_thread.run_sync(queryfn)
558561

559-
def query_libviewer_index(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
562+
async def query_libviewer_index(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
560563
"""Query the parquet files using libviewer.
561564
562565
This is the new implementation using libviewer doing row-group and page pruning.
@@ -574,7 +577,7 @@ def query_libviewer_index(self, offset: int, length: int) -> tuple[pa.Table, lis
574577
raise IndexError("Length must be non-negative")
575578

576579
try:
577-
batches, _files_to_index = self.viewer_index.sync_scan(
580+
batches, _files_to_index = await self.viewer_index.scan(
578581
offset=offset, limit=length, scan_size_limit=self.max_scan_size
579582
)
580583
except lv.DatasetError as e:

libs/libcommon/tests/test_parquet_utils.py

Lines changed: 33 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -454,37 +454,39 @@ def test_indexer_get_rows_index_sharded_with_parquet_metadata(
454454
assert metadata_path.exists()
455455

456456

457-
def test_rows_index_query_with_parquet_metadata(
457+
@pytest.mark.asyncio
458+
async def test_rows_index_query_with_parquet_metadata(
458459
rows_index_with_parquet_metadata: RowsIndex, ds_sharded: Dataset
459460
) -> None:
460461
assert isinstance(rows_index_with_parquet_metadata.parquet_index, ParquetIndexWithMetadata)
461462
assert not hasattr(rows_index_with_parquet_metadata, "viewer_index")
462-
result, _ = rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=3)
463+
result, _ = await rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=3)
463464
assert result.to_pydict() == ds_sharded[1:4]
464465

465-
result, _ = rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=-1)
466+
result, _ = await rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=-1)
466467
assert result.to_pydict() == ds_sharded[:0]
467468

468-
result, _ = rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=0)
469+
result, _ = await rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=0)
469470
assert result.to_pydict() == ds_sharded[:0]
470471

471-
result, _ = rows_index_with_parquet_metadata.query_parquet_index(offset=999999, length=1)
472+
result, _ = await rows_index_with_parquet_metadata.query_parquet_index(offset=999999, length=1)
472473
assert result.to_pydict() == ds_sharded[:0]
473474

474-
result, _ = rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=99999999)
475+
result, _ = await rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=99999999)
475476
assert result.to_pydict() == ds_sharded[1:]
476477

477478
with pytest.raises(IndexError):
478-
rows_index_with_parquet_metadata.query_parquet_index(offset=-1, length=2)
479+
await rows_index_with_parquet_metadata.query_parquet_index(offset=-1, length=2)
479480

480481
# test that the other query() calls query_parquet_index() rather than query_libviewer_index()
481-
result, _ = rows_index_with_parquet_metadata.query(offset=1, length=3)
482+
result, _ = await rows_index_with_parquet_metadata.query(offset=1, length=3)
482483
assert result.to_pydict() == ds_sharded[1:4]
483484
with pytest.raises(AttributeError):
484-
rows_index_with_parquet_metadata.query_libviewer_index(offset=1, length=3)
485+
await rows_index_with_parquet_metadata.query_libviewer_index(offset=1, length=3)
485486

486487

487-
def test_rows_index_query_with_parquet_metadata_libviewer(
488+
@pytest.mark.asyncio
489+
async def test_rows_index_query_with_parquet_metadata_libviewer(
488490
ds_sharded: Dataset,
489491
ds_sharded_fs: AbstractFileSystem,
490492
dataset_sharded_with_config_parquet_metadata: dict[str, Any],
@@ -508,27 +510,28 @@ def test_rows_index_query_with_parquet_metadata_libviewer(
508510

509511
assert isinstance(rows_index_with_parquet_metadata.viewer_index, lv.Dataset)
510512
assert not hasattr(rows_index_with_parquet_metadata, "parquet_index")
511-
result, _truncated_cols = rows_index_with_parquet_metadata.query_libviewer_index(offset=1, length=3)
513+
result, _truncated_cols = await rows_index_with_parquet_metadata.query_libviewer_index(offset=1, length=3)
512514
assert result.to_pydict() == ds_sharded[1:4]
513-
result, _truncated_cols = rows_index_with_parquet_metadata.query_libviewer_index(offset=1, length=0)
515+
result, _truncated_cols = await rows_index_with_parquet_metadata.query_libviewer_index(offset=1, length=0)
514516
assert result.to_pydict() == ds_sharded[:0]
515-
result, _truncated_cols = rows_index_with_parquet_metadata.query_libviewer_index(offset=999999, length=1)
517+
result, _truncated_cols = await rows_index_with_parquet_metadata.query_libviewer_index(offset=999999, length=1)
516518
assert result.to_pydict() == ds_sharded[:0]
517-
result, _truncated_cols = rows_index_with_parquet_metadata.query_libviewer_index(offset=1, length=99999999)
519+
result, _truncated_cols = await rows_index_with_parquet_metadata.query_libviewer_index(offset=1, length=99999999)
518520
assert result.to_pydict() == ds_sharded[1:]
519521
with pytest.raises(IndexError):
520-
rows_index_with_parquet_metadata.query_libviewer_index(offset=0, length=-1)
522+
await rows_index_with_parquet_metadata.query_libviewer_index(offset=0, length=-1)
521523
with pytest.raises(IndexError):
522-
rows_index_with_parquet_metadata.query_libviewer_index(offset=-1, length=2)
524+
await rows_index_with_parquet_metadata.query_libviewer_index(offset=-1, length=2)
523525

524526
# test that the other query() calls query_libviewer_index() rather than query_parquet_index()
525-
result, _ = rows_index_with_parquet_metadata.query(offset=1, length=3)
527+
result, _ = await rows_index_with_parquet_metadata.query(offset=1, length=3)
526528
assert result.to_pydict() == ds_sharded[1:4]
527529
with pytest.raises(AttributeError):
528-
rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=3)
530+
await rows_index_with_parquet_metadata.query_parquet_index(offset=1, length=3)
529531

530532

531-
def test_rows_index_query_with_too_big_rows(
533+
@pytest.mark.asyncio
534+
async def test_rows_index_query_with_too_big_rows(
532535
parquet_metadata_directory: StrPath,
533536
ds_sharded: Dataset,
534537
ds_sharded_fs: AbstractFileSystem,
@@ -546,7 +549,7 @@ def test_rows_index_query_with_too_big_rows(
546549
)
547550

548551
with pytest.raises(TooBigRows):
549-
index.query_parquet_index(offset=0, length=3)
552+
await index.query_parquet_index(offset=0, length=3)
550553

551554
with patch("libcommon.parquet_utils.libviewer_config", LibviewerConfig(enable_for_datasets=True)):
552555
index = RowsIndex(
@@ -563,10 +566,11 @@ def test_rows_index_query_with_too_big_rows(
563566

564567
# test the same with page pruning API
565568
with pytest.raises(TooBigRows):
566-
index.query_libviewer_index(offset=0, length=2)
569+
await index.query_libviewer_index(offset=0, length=2)
567570

568571

569-
def test_rows_index_query_with_empty_dataset(
572+
@pytest.mark.asyncio
573+
async def test_rows_index_query_with_empty_dataset(
570574
ds_empty: Dataset,
571575
ds_empty_fs: AbstractFileSystem,
572576
dataset_empty_with_config_parquet_metadata: dict[str, Any],
@@ -585,10 +589,10 @@ def test_rows_index_query_with_empty_dataset(
585589

586590
assert isinstance(index.parquet_index, ParquetIndexWithMetadata)
587591
assert not hasattr(index, "viewer_index")
588-
result, _ = index.query_parquet_index(offset=0, length=1)
592+
result, _ = await index.query_parquet_index(offset=0, length=1)
589593
assert result.to_pydict() == ds_empty[:0]
590594
with pytest.raises(IndexError):
591-
index.query_parquet_index(offset=-1, length=2)
595+
await index.query_parquet_index(offset=-1, length=2)
592596

593597
# test the same with page pruning API
594598
import libviewer as lv
@@ -608,13 +612,14 @@ def test_rows_index_query_with_empty_dataset(
608612

609613
assert isinstance(index.viewer_index, lv.Dataset)
610614
assert not hasattr(index, "parquet_index")
611-
result, _ = index.query_libviewer_index(offset=0, length=1)
615+
result, _ = await index.query_libviewer_index(offset=0, length=1)
612616
assert result.to_pydict() == ds_empty[:0]
613617
with pytest.raises(IndexError):
614-
index.query_libviewer_index(offset=-1, length=2)
618+
await index.query_libviewer_index(offset=-1, length=2)
615619

616620

617-
def test_indexer_schema_mistmatch_error(
621+
@pytest.mark.asyncio
622+
async def test_indexer_schema_mistmatch_error(
618623
ds_sharded_fs: AbstractFileSystem,
619624
ds_sharded_fs_with_different_schema: AbstractFileSystem,
620625
dataset_sharded_with_config_parquet_metadata: dict[str, Any],
@@ -632,7 +637,7 @@ def test_indexer_schema_mistmatch_error(
632637
max_arrow_data_in_memory=9999999999,
633638
)
634639
with pytest.raises(SchemaMismatchError):
635-
index.query_parquet_index(offset=0, length=3)
640+
await index.query_parquet_index(offset=0, length=3)
636641

637642

638643
@pytest.mark.parametrize(

libs/libviewer/libviewer/dataset.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
import os
2+
import functools
23

4+
import anyio
35
from huggingface_hub import hf_hub_download, list_repo_files
46

57
from ._internal import PyDataset, PyDatasetError as DatasetError # noqa: F401
@@ -101,3 +103,15 @@ def from_cache(repo, metadata_store, revision=None, download=False):
101103
data_store="file://",
102104
metadata_store=metadata_store,
103105
)
106+
107+
def sync_scan(
108+
self, limit=None, offset=None, scan_size_limit=1 * 1024 * 1024 * 1024
109+
):
110+
fn = functools.partial(
111+
self.scan, limit=limit, offset=offset, scan_size_limit=scan_size_limit
112+
)
113+
return anyio.run(fn)
114+
115+
def sync_index(self, files=None):
116+
fn = functools.partial(self.index, files=files)
117+
return anyio.run(fn)

libs/libviewer/src/lib.rs

Lines changed: 1 addition & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -73,25 +73,6 @@ impl PyDataset {
7373
Ok(self.dataset.files.clone())
7474
}
7575

76-
#[pyo3(signature = (limit=None, offset=None, scan_size_limit=DEFAULT_SCAN_SIZE_LIMIT))]
77-
fn sync_scan(
78-
&self,
79-
py: Python<'_>,
80-
limit: Option<u64>,
81-
offset: Option<u64>,
82-
scan_size_limit: u64,
83-
) -> PyResult<(Vec<Py<PyAny>>, Vec<IndexedFile>)> {
84-
let rt = tokio::runtime::Runtime::new()?;
85-
let (record_batches, files_to_index) =
86-
rt.block_on(self.dataset.scan(limit, offset, scan_size_limit))?;
87-
let pyarrow_batches = record_batches
88-
.into_iter()
89-
.map(|batch| Ok(batch.into_pyarrow(py)?.unbind()))
90-
.collect::<PyResult<Vec<_>>>()?;
91-
92-
Ok((pyarrow_batches, files_to_index))
93-
}
94-
9576
#[pyo3(signature = (limit=None, offset=None, scan_size_limit=DEFAULT_SCAN_SIZE_LIMIT))]
9677
fn scan<'py>(
9778
&self,
@@ -117,13 +98,6 @@ impl PyDataset {
11798
})
11899
}
119100

120-
#[pyo3(signature = (files=None))]
121-
fn sync_index(&self, files: Option<Vec<IndexedFile>>) -> PyResult<Vec<IndexedFile>> {
122-
let rt = tokio::runtime::Runtime::new()?;
123-
let indexed_files = rt.block_on(self.dataset.index(files.as_deref()))?;
124-
Ok(indexed_files)
125-
}
126-
127101
#[pyo3(signature = (files=None))]
128102
fn index<'py>(
129103
&self,
@@ -143,6 +117,7 @@ impl PyDataset {
143117
fn dv(m: &Bound<'_, PyModule>) -> PyResult<()> {
144118
// Bridge the Rust log crate with the Python logging module
145119
// pyo3_log::init();
120+
env_logger::init();
146121

147122
m.add_class::<PyDataset>()?;
148123
m.add("PyDatasetError", m.py().get_type::<PyDatasetError>())?;

services/rows/src/rows/routes/rows.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,7 @@ async def rows_endpoint(request: Request) -> Response:
103103
with StepProfiler(method="rows_endpoint", step="query the rows"):
104104
try:
105105
# Some datasets have very long binary data that we truncate
106-
pa_table, truncated_columns = rows_index.query(offset=offset, length=length)
106+
pa_table, truncated_columns = await rows_index.query(offset=offset, length=length)
107107
except TooBigRows as err:
108108
raise TooBigContentError(str(err)) from None
109109
with StepProfiler(method="rows_endpoint", step="transform to a list"):

services/worker/src/worker/job_runners/split/first_rows.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,12 @@
22
# Copyright 2022 The HuggingFace Authors.
33

44

5+
import functools
56
import logging
67
from pathlib import Path
78
from typing import Optional
89

10+
import anyio
911
from datasets import IterableDataset, get_dataset_config_info, load_dataset
1012
from fsspec.implementations.http import HTTPFileSystem
1113
from libcommon.constants import MAX_NUM_ROWS_PER_PAGE
@@ -117,7 +119,8 @@ def compute_first_rows_from_parquet_response(
117119
def get_rows_content(rows_max_number: int) -> RowsContent:
118120
try:
119121
# Some datasets have very long binary data that we truncate
120-
pa_table, truncated_columns = rows_index.query(offset=0, length=rows_max_number)
122+
queryfn = functools.partial(rows_index.query, offset=0, length=rows_max_number)
123+
pa_table, truncated_columns = anyio.run(queryfn)
121124
return RowsContent(
122125
rows=pa_table.to_pylist(),
123126
all_fetched=rows_index.num_rows_total <= rows_max_number,

0 commit comments

Comments (0)