Commit 06b6e02

eddyxu and lhoestq authored
Add lance format support (#7913)
* add lance as supported format
* add more test
* remove debug prints
* re
* simplify
* do not need configure file
* push
* claude code
* pass tests
* handle streaming uri
* style
* support {local, streaming} x {dataset, single files}
* added columns pushdown to LanceFileReader

---------

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 0feb65d commit 06b6e02
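For context before the diffs: once this builder is registered, loading Lance data is expected to look roughly like the sketch below. The local path and repo id are placeholders, not taken from the commit; the columns and streaming arguments mirror what the new tests exercise.

from datasets import load_dataset

# Local Lance dataset directory (e.g. produced by lance.write_dataset), non-streaming:
ds = load_dataset("path/to/my_dataset.lance", split="train")

# Hub repo laid out like the new test fixture (data/train.lance, data/test.lance),
# with column pushdown and streaming:
ds_stream = load_dataset("username/my-lance-repo", columns=["id", "text"], streaming=True, split="train")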

File tree

6 files changed: +296 -7 lines changed

setup.py

Lines changed: 1 addition & 0 deletions

@@ -167,6 +167,7 @@
     "elasticsearch>=7.17.12,<8.0.0",  # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch(); 7.9.1 has legacy numpy.float_ which was fixed in https://github.com/elastic/elasticsearch-py/pull/2551.
     "faiss-cpu>=1.8.0.post1",  # Pins numpy < 2
     "h5py",
+    "pylance",
     "jax>=0.3.14; sys_platform != 'win32'",
     "jaxlib>=0.3.14; sys_platform != 'win32'",
     "lz4; python_version < '3.14'",  # python 3.14 gives ImportError: cannot import name '_compression' from partially initialized module 'lz4.frame

src/datasets/load.py

Lines changed: 9 additions & 7 deletions

@@ -66,8 +66,10 @@
 from .iterable_dataset import IterableDataset
 from .naming import camelcase_to_snakecase, snakecase_to_camelcase
 from .packaged_modules import (
+    _ALL_ALLOWED_EXTENSIONS,
     _EXTENSION_TO_MODULE,
     _MODULE_TO_EXTENSIONS,
+    _MODULE_TO_METADATA_EXTENSIONS,
     _MODULE_TO_METADATA_FILE_NAMES,
     _PACKAGED_DATASETS_MODULES,
 )
@@ -91,8 +93,6 @@
 
 logger = get_logger(__name__)
 
-ALL_ALLOWED_EXTENSIONS = list(_EXTENSION_TO_MODULE.keys()) + [".zip"]
-
 
 class _InitializeConfiguredDatasetBuilder:
     """
@@ -328,7 +328,7 @@ def create_builder_configs_from_metadata_configs(
         )
         config_data_files_dict = DataFilesPatternsDict.from_patterns(
             config_patterns,
-            allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+            allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
         )
     except EmptyDatasetError as e:
         raise EmptyDatasetError(
@@ -436,14 +436,15 @@ def get_module(self) -> DatasetModule:
         data_files = DataFilesDict.from_patterns(
             patterns,
             base_path=base_path,
-            allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+            allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
         )
         module_name, default_builder_kwargs = infer_module_for_data_files(
             data_files=data_files,
             path=self.path,
         )
         data_files = data_files.filter(
-            extensions=_MODULE_TO_EXTENSIONS[module_name], file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name]
+            extensions=_MODULE_TO_EXTENSIONS[module_name] + _MODULE_TO_METADATA_EXTENSIONS[module_name],
+            file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name],
         )
         module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
         if metadata_configs:
@@ -633,7 +634,7 @@ def get_module(self) -> DatasetModule:
         data_files = DataFilesDict.from_patterns(
             patterns,
             base_path=base_path,
-            allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+            allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
             download_config=self.download_config,
         )
         module_name, default_builder_kwargs = infer_module_for_data_files(
@@ -642,7 +643,8 @@
             download_config=self.download_config,
         )
         data_files = data_files.filter(
-            extensions=_MODULE_TO_EXTENSIONS[module_name], file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name]
+            extensions=_MODULE_TO_EXTENSIONS[module_name] + _MODULE_TO_METADATA_EXTENSIONS[module_name],
+            file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name],
         )
         module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
         if metadata_configs:
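The point of threading the metadata extensions through data_files.filter() is that Lance's sidecar files survive data-file resolution. A standalone sketch of that behavior, with hypothetical file names; the extension lists are the ones introduced by this commit:

# not datasets code, just an illustration of the new filter arguments
data_extensions = [".lance"]                          # what _MODULE_TO_EXTENSIONS["lance"] is assumed to hold
metadata_extensions = [".idx", ".txn", ".manifest"]   # Lance.METADATA_EXTENSIONS added in this commit

files = [
    "data/train.lance/data/0001.lance",
    "data/train.lance/_versions/1.manifest",
    "data/train.lance/_transactions/0-abc.txn",
    "data/train.lance/_indices/uuid/index.idx",
]

allowed = data_extensions + metadata_extensions
kept = [f for f in files if any(f.endswith(ext) for ext in allowed)]
print(kept)  # all four paths survive; with data_extensions alone only the *.lance data file would remain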

src/datasets/packaged_modules/__init__.py

Lines changed: 14 additions & 0 deletions

@@ -12,6 +12,7 @@
 from .hdf5 import hdf5
 from .imagefolder import imagefolder
 from .json import json
+from .lance import lance
 from .niftifolder import niftifolder
 from .pandas import pandas
 from .parquet import parquet
@@ -53,6 +54,7 @@ def _hash_python_lines(lines: list[str]) -> str:
     "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
     "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())),
     "eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())),
+    "lance": (lance.__name__, _hash_python_lines(inspect.getsource(lance).splitlines())),
 }
 
 # get importable module names and hash for caching
@@ -85,6 +87,7 @@ def _hash_python_lines(lines: list[str]) -> str:
     ".hdf5": ("hdf5", {}),
     ".h5": ("hdf5", {}),
     ".eval": ("eval", {}),
+    ".lance": ("lance", {}),
 }
 _EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
@@ -114,3 +117,14 @@ def _hash_python_lines(lines: list[str]) -> str:
 _MODULE_TO_METADATA_FILE_NAMES["videofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
 _MODULE_TO_METADATA_FILE_NAMES["pdffolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
 _MODULE_TO_METADATA_FILE_NAMES["niftifolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
+
+_MODULE_TO_METADATA_EXTENSIONS: Dict[str, List[str]] = {}
+for _module in _MODULE_TO_EXTENSIONS:
+    _MODULE_TO_METADATA_EXTENSIONS[_module] = []
+_MODULE_TO_METADATA_EXTENSIONS["lance"] = lance.Lance.METADATA_EXTENSIONS
+
+# Total
+
+_ALL_EXTENSIONS = list(_EXTENSION_TO_MODULE.keys()) + [".zip"]
+_ALL_METADATA_EXTENSIONS = list({_ext for _exts in _MODULE_TO_METADATA_EXTENSIONS.values() for _ext in _exts})
+_ALL_ALLOWED_EXTENSIONS = _ALL_EXTENSIONS + _ALL_METADATA_EXTENSIONS
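A quick, hedged way to see the new registration at work (run against a checkout with this commit applied; the parquet entry is only there to show the default for other modules):

from datasets.packaged_modules import _EXTENSION_TO_MODULE, _MODULE_TO_METADATA_EXTENSIONS

assert _EXTENSION_TO_MODULE[".lance"] == ("lance", {})                      # *.lance files resolve to the new "lance" builder
assert _MODULE_TO_METADATA_EXTENSIONS["lance"] == [".idx", ".txn", ".manifest"]
assert _MODULE_TO_METADATA_EXTENSIONS["parquet"] == []                      # non-Lance modules default to no metadata extensions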

src/datasets/packaged_modules/lance/__init__.py

Whitespace-only changes.
src/datasets/packaged_modules/lance/lance.py

Lines changed: 166 additions & 0 deletions

@@ -0,0 +1,166 @@
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional

import pyarrow as pa
from huggingface_hub import HfApi

import datasets
from datasets.builder import Key
from datasets.table import table_cast
from datasets.utils.file_utils import is_local_path


if TYPE_CHECKING:
    import lance
    import lance.file

logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class LanceConfig(datasets.BuilderConfig):
    """
    BuilderConfig for Lance format.

    Args:
        features: (`Features`, *optional*):
            Cast the data to `features`.
        columns: (`List[str]`, *optional*):
            List of columns to load, the other ones are ignored.
        batch_size: (`int`, *optional*):
            Size of the RecordBatches to iterate on. Default to 256.
        token: (`str`, *optional*):
            Optional HF token to use to download datasets.
    """

    features: Optional[datasets.Features] = None
    columns: Optional[List[str]] = None
    batch_size: Optional[int] = 256
    token: Optional[str] = None


def resolve_dataset_uris(files: List[str]) -> Dict[str, List[str]]:
    dataset_uris = set()
    for file_path in files:
        path = Path(file_path)
        if path.parent.name in {"_transactions", "_indices", "_versions"}:
            dataset_root = path.parent.parent
            dataset_uris.add(str(dataset_root))
    return list(dataset_uris)


def _fix_hf_uri(uri: str) -> str:
    # replace the revision tag from hf uri
    if "@" in uri:
        matched = re.match(r"(hf://.+?)(@[0-9a-f]+)(/.*)", uri)
        if matched:
            uri = matched.group(1) + matched.group(3)
    return uri


def _fix_local_version_file(uri: str) -> str:
    # replace symlinks with real files for _version
    if "/_versions/" in uri and is_local_path(uri):
        path = Path(uri)
        if path.is_symlink():
            data = path.read_bytes()
            path.unlink()
            path.write_bytes(data)
    return uri


class Lance(datasets.ArrowBasedBuilder):
    BUILDER_CONFIG_CLASS = LanceConfig
    METADATA_EXTENSIONS = [".idx", ".txn", ".manifest"]

    def _info(self):
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        import lance
        import lance.file

        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        if self.repo_id:
            api = HfApi(**dl_manager.download_config.storage_options["hf"])
            dataset_sha = api.dataset_info(self.repo_id).sha
            if dataset_sha != self.hash:
                raise NotImplementedError(
                    f"lance doesn't support loading other revisions than 'main' yet, but got {self.hash}"
                )
        data_files = dl_manager.download(self.config.data_files)

        # TODO: remove once Lance supports HF links with revisions
        data_files = {split: [_fix_hf_uri(file) for file in files] for split, files in data_files.items()}
        # TODO: remove once Lance supports symlinks for _version files
        data_files = {split: [_fix_local_version_file(file) for file in files] for split, files in data_files.items()}

        splits = []
        for split_name, files in data_files.items():
            storage_options = dl_manager.download_config.storage_options.get(files[0].split("://", 0)[0] + "://")

            lance_dataset_uris = resolve_dataset_uris(files)
            if lance_dataset_uris:
                fragments = [
                    frag
                    for uri in lance_dataset_uris
                    for frag in lance.dataset(uri, storage_options=storage_options).get_fragments()
                ]
                if self.info.features is None:
                    pa_schema = fragments[0]._ds.schema
                splits.append(
                    datasets.SplitGenerator(
                        name=split_name,
                        gen_kwargs={"fragments": fragments, "lance_files": None},
                    )
                )
            else:
                lance_files = [
                    lance.file.LanceFileReader(file, storage_options=storage_options, columns=self.config.columns)
                    for file in files
                ]
                if self.info.features is None:
                    pa_schema = lance_files[0].metadata().schema
                splits.append(
                    datasets.SplitGenerator(
                        name=split_name,
                        gen_kwargs={"fragments": None, "lance_files": lance_files},
                    )
                )
        if self.info.features is None:
            if self.config.columns:
                fields = [
                    pa_schema.field(name) for name in self.config.columns if pa_schema.get_field_index(name) != -1
                ]
                pa_schema = pa.schema(fields)
            self.info.features = datasets.Features.from_arrow_schema(pa_schema)

        return splits

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        if self.info.features is not None:
            # more expensive cast to support nested features with keys in a different order
            # allows str <-> int/float or str to Audio for example
            pa_table = table_cast(pa_table, self.info.features.arrow_schema)
        return pa_table

    def _generate_tables(
        self,
        fragments: Optional[List["lance.LanceFragment"]],
        lance_files: Optional[List["lance.file.LanceFileReader"]],
    ):
        if fragments:
            for frag_idx, fragment in enumerate(fragments):
                for batch_idx, batch in enumerate(
                    fragment.to_batches(columns=self.config.columns, batch_size=self.config.batch_size)
                ):
                    table = pa.Table.from_batches([batch])
                    yield Key(frag_idx, batch_idx), self._cast_table(table)
        else:
            for file_idx, lance_file in enumerate(lance_files):
                for batch_idx, batch in enumerate(lance_file.read_all(batch_size=self.config.batch_size).to_batches()):
                    table = pa.Table.from_batches([batch])
                    yield Key(file_idx, batch_idx), self._cast_table(table)
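To make the two code paths in _split_generators concrete: resolve_dataset_uris() maps resolved files back to Lance dataset roots by looking for the _transactions, _indices, and _versions directories. A hypothetical call follows; the cache paths are made up, and the import path assumes the module lives at datasets/packaged_modules/lance/lance.py as shown above.

from datasets.packaged_modules.lance.lance import resolve_dataset_uris

files = [
    "/cache/data/train.lance/_versions/1.manifest",
    "/cache/data/train.lance/_transactions/0-abc.txn",
    "/cache/data/train.lance/data/0001.lance",
]
print(resolve_dataset_uris(files))  # ["/cache/data/train.lance"]
# A non-empty result means the builder opens lance.dataset() on each root and iterates its
# fragments; an empty result (plain standalone *.lance files) makes it fall back to
# lance.file.LanceFileReader per file, which is where the columns pushdown is applied.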
Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
import lance
import numpy as np
import pyarrow as pa
import pytest

from datasets import load_dataset


@pytest.fixture
def lance_dataset(tmp_path) -> str:
    data = pa.table(
        {
            "id": pa.array([1, 2, 3, 4]),
            "value": pa.array([10.0, 20.0, 30.0, 40.0]),
            "text": pa.array(["a", "b", "c", "d"]),
            "vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 16, pa.float32()), list_size=4),
        }
    )
    dataset_path = tmp_path / "test_dataset.lance"
    lance.write_dataset(data, dataset_path)
    return str(dataset_path)


@pytest.fixture
def lance_hf_dataset(tmp_path) -> str:
    data = pa.table(
        {
            "id": pa.array([1, 2, 3, 4]),
            "value": pa.array([10.0, 20.0, 30.0, 40.0]),
            "text": pa.array(["a", "b", "c", "d"]),
            "vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 16, pa.float32()), list_size=4),
        }
    )
    dataset_dir = tmp_path / "data" / "train.lance"
    dataset_dir.parent.mkdir(parents=True, exist_ok=True)
    lance.write_dataset(data, dataset_dir)
    lance.write_dataset(data[:2], tmp_path / "data" / "test.lance")

    with open(tmp_path / "README.md", "w") as f:
        f.write("""---
size_categories:
- 1M<n<10M
source_datasets:
- lance_test
---
# Test Lance Dataset\n\n
# My Markdown is fancier\n
""")

    return str(tmp_path)


def test_load_lance_dataset(lance_dataset):
    dataset_dict = load_dataset(lance_dataset)
    assert "train" in dataset_dict.keys()

    dataset = dataset_dict["train"]
    assert "id" in dataset.column_names
    assert "value" in dataset.column_names
    assert "text" in dataset.column_names
    assert "vector" in dataset.column_names
    ids = dataset["id"]
    assert ids == [1, 2, 3, 4]


@pytest.mark.parametrize("streaming", [False, True])
def test_load_hf_dataset(lance_hf_dataset, streaming):
    dataset_dict = load_dataset(lance_hf_dataset, columns=["id", "text"], streaming=streaming)
    assert "train" in dataset_dict.keys()
    assert "test" in dataset_dict.keys()
    dataset = dataset_dict["train"]

    assert "id" in dataset.column_names
    assert "text" in dataset.column_names
    assert "value" not in dataset.column_names
    assert "vector" not in dataset.column_names
    ids = list(dataset["id"])
    assert ids == [1, 2, 3, 4]
    text = list(dataset["text"])
    assert text == ["a", "b", "c", "d"]
    assert "value" not in dataset.column_names


def test_load_vectors(lance_hf_dataset):
    dataset_dict = load_dataset(lance_hf_dataset, columns=["vector"])
    assert "train" in dataset_dict.keys()
    dataset = dataset_dict["train"]

    assert "vector" in dataset.column_names
    vectors = dataset.data["vector"].combine_chunks().values.to_numpy(zero_copy_only=False)
    assert np.allclose(vectors, np.full(16, 0.1))


@pytest.mark.parametrize("streaming", [False, True])
def test_load_lance_streaming_modes(lance_hf_dataset, streaming):
    """Test loading Lance dataset in both streaming and non-streaming modes."""
    from datasets import IterableDataset

    ds = load_dataset(lance_hf_dataset, split="train", streaming=streaming)
    if streaming:
        assert isinstance(ds, IterableDataset)
        items = list(ds)
    else:
        items = list(ds)
    assert len(items) == 4
    assert all("id" in item for item in items)