Merged
Changes from 12 commits
1 change: 1 addition & 0 deletions setup.py
@@ -167,6 +167,7 @@
"elasticsearch>=7.17.12,<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch(); 7.9.1 has legacy numpy.float_ which was fixed in https://github.com/elastic/elasticsearch-py/pull/2551.
"faiss-cpu>=1.8.0.post1", # Pins numpy < 2
"h5py",
"pylance",
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4; python_version < '3.14'", # python 3.14 gives ImportError: cannot import name '_compression' from partially initialized module 'lz4.frame
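The only new runtime dependency is pylance, the PyPI distribution that ships the `lance` Python module used by the new builder. A minimal sketch of the import guard this implies (the error wording is illustrative, not the library's own):

# Minimal sketch: the package installs as "pylance" but is imported as "lance".
try:
    import lance  # noqa: F401  # provided by the "pylance" distribution
except ImportError as err:
    raise ImportError("Reading Lance datasets requires the optional dependency: pip install pylance") from err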
16 changes: 9 additions & 7 deletions src/datasets/load.py
@@ -66,8 +66,10 @@
from .iterable_dataset import IterableDataset
from .naming import camelcase_to_snakecase, snakecase_to_camelcase
from .packaged_modules import (
_ALL_ALLOWED_EXTENSIONS,
_EXTENSION_TO_MODULE,
_MODULE_TO_EXTENSIONS,
_MODULE_TO_METADATA_EXTENSIONS,
_MODULE_TO_METADATA_FILE_NAMES,
_PACKAGED_DATASETS_MODULES,
)
@@ -91,8 +93,6 @@

logger = get_logger(__name__)

ALL_ALLOWED_EXTENSIONS = list(_EXTENSION_TO_MODULE.keys()) + [".zip"]


class _InitializeConfiguredDatasetBuilder:
"""
@@ -328,7 +328,7 @@ def create_builder_configs_from_metadata_configs(
)
config_data_files_dict = DataFilesPatternsDict.from_patterns(
config_patterns,
allowed_extensions=ALL_ALLOWED_EXTENSIONS,
allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
)
except EmptyDatasetError as e:
raise EmptyDatasetError(
@@ -436,14 +436,15 @@ def get_module(self) -> DatasetModule:
data_files = DataFilesDict.from_patterns(
patterns,
base_path=base_path,
allowed_extensions=ALL_ALLOWED_EXTENSIONS,
allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
)
module_name, default_builder_kwargs = infer_module_for_data_files(
data_files=data_files,
path=self.path,
)
data_files = data_files.filter(
extensions=_MODULE_TO_EXTENSIONS[module_name], file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name]
extensions=_MODULE_TO_EXTENSIONS[module_name] + _MODULE_TO_METADATA_EXTENSIONS[module_name],
file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name],
)
module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
if metadata_configs:
@@ -633,7 +634,7 @@ def get_module(self) -> DatasetModule:
data_files = DataFilesDict.from_patterns(
patterns,
base_path=base_path,
allowed_extensions=ALL_ALLOWED_EXTENSIONS,
allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
download_config=self.download_config,
)
module_name, default_builder_kwargs = infer_module_for_data_files(
@@ -642,7 +643,8 @@
download_config=self.download_config,
)
data_files = data_files.filter(
extensions=_MODULE_TO_EXTENSIONS[module_name], file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name]
extensions=_MODULE_TO_EXTENSIONS[module_name] + _MODULE_TO_METADATA_EXTENSIONS[module_name],
file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name],
)
module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
if metadata_configs:
14 changes: 14 additions & 0 deletions src/datasets/packaged_modules/__init__.py
@@ -12,6 +12,7 @@
from .hdf5 import hdf5
from .imagefolder import imagefolder
from .json import json
from .lance import lance
from .niftifolder import niftifolder
from .pandas import pandas
from .parquet import parquet
@@ -53,6 +54,7 @@ def _hash_python_lines(lines: list[str]) -> str:
"xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
"hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())),
"eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())),
"lance": (lance.__name__, _hash_python_lines(inspect.getsource(lance).splitlines())),
}

# get importable module names and hash for caching
@@ -85,6 +87,7 @@ def _hash_python_lines(lines: list[str]) -> str:
".hdf5": ("hdf5", {}),
".h5": ("hdf5", {}),
".eval": ("eval", {}),
".lance": ("lance", {}),
}
_EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
@@ -114,3 +117,14 @@ def _hash_python_lines(lines: list[str]) -> str:
_MODULE_TO_METADATA_FILE_NAMES["videofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
_MODULE_TO_METADATA_FILE_NAMES["pdffolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
_MODULE_TO_METADATA_FILE_NAMES["niftifolder"] = imagefolder.ImageFolder.METADATA_FILENAMES

_MODULE_TO_METADATA_EXTENSIONS: Dict[str, List[str]] = {}
for _module in _MODULE_TO_EXTENSIONS:
_MODULE_TO_METADATA_EXTENSIONS[_module] = []
_MODULE_TO_METADATA_EXTENSIONS["lance"] = lance.Lance.METADATA_EXTENSIONS

# Aggregate allow-lists (data extensions + metadata extensions)

_ALL_EXTENSIONS = list(_EXTENSION_TO_MODULE.keys()) + [".zip"]
_ALL_METADATA_EXTENSIONS = list({_ext for _exts in _MODULE_TO_METADATA_EXTENSIONS.values() for _ext in _exts})
_ALL_ALLOWED_EXTENSIONS = _ALL_EXTENSIONS + _ALL_METADATA_EXTENSIONS
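Taken together, these registry entries mean `.lance` files resolve to the new "lance" builder and its metadata files survive the data-files filter in load.py. A quick illustrative check, assuming `_MODULE_TO_EXTENSIONS` is the usual inversion of `_EXTENSION_TO_MODULE` built earlier in this file:

from datasets.packaged_modules import (
    _ALL_ALLOWED_EXTENSIONS,
    _EXTENSION_TO_MODULE,
    _MODULE_TO_EXTENSIONS,
    _MODULE_TO_METADATA_EXTENSIONS,
)

# ".lance" now resolves to the packaged "lance" builder
assert _EXTENSION_TO_MODULE[".lance"] == ("lance", {})
assert _MODULE_TO_EXTENSIONS["lance"] == [".lance"]
assert _MODULE_TO_METADATA_EXTENSIONS["lance"] == [".idx", ".txn", ".manifest"]
# all of the above end up in the allow-list used when resolving data file patterns
assert set([".lance", ".idx", ".txn", ".manifest"]) <= set(_ALL_ALLOWED_EXTENSIONS)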
src/datasets/packaged_modules/lance/__init__.py
Empty file.
163 changes: 163 additions & 0 deletions src/datasets/packaged_modules/lance/lance.py
@@ -0,0 +1,163 @@
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional

import pyarrow as pa
from huggingface_hub import HfApi

import datasets
from datasets.builder import Key
from datasets.table import table_cast
from datasets.utils.file_utils import is_local_path


if TYPE_CHECKING:
import lance
import lance.file

logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class LanceConfig(datasets.BuilderConfig):
"""
BuilderConfig for Lance format.

Args:
features: (`Features`, *optional*):
Cast the data to `features`.
columns: (`List[str]`, *optional*):
List of columns to load, the other ones are ignored.
batch_size: (`int`, *optional*):
Size of the RecordBatches to iterate on. Defaults to 256.
token: (`str`, *optional*):
Optional HF token to use to download datasets.
"""

features: Optional[datasets.Features] = None
columns: Optional[List[str]] = None
batch_size: Optional[int] = 256
token: Optional[str] = None
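These fields are regular builder-config kwargs, so they can be forwarded through `load_dataset`; for example (the local path is hypothetical):

from datasets import load_dataset

ds = load_dataset(
    "./my_dataset.lance",      # hypothetical local Lance dataset directory
    columns=["id", "text"],    # read only these columns
    batch_size=512,            # size of the RecordBatches yielded while generating the split
    split="train",
)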


def resolve_dataset_uris(files: List[str]) -> List[str]:
dataset_uris = set()
for file_path in files:
path = Path(file_path)
if path.parent.name in {"_transactions", "_indices", "_versions"}:
dataset_root = path.parent.parent
dataset_uris.add(str(dataset_root))
return list(dataset_uris)
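For illustration, resolved metadata files are mapped back to the root of the Lance dataset they belong to (hypothetical POSIX paths):

files = [
    "/data/train.lance/_versions/1.manifest",
    "/data/train.lance/_transactions/0-abc.txn",
    "/data/test.lance/_versions/1.manifest",
]
assert sorted(resolve_dataset_uris(files)) == ["/data/test.lance", "/data/train.lance"]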


def _fix_hf_uri(uri: str) -> str:
# strip the revision tag from the hf uri
if "@" in uri:
matched = re.match(r"(hf://.+?)(@[0-9a-f]+)(/.*)", uri)
if matched:
uri = matched.group(1) + matched.group(3)
return uri
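For example, a resolved `hf://` URI pinned to a commit sha has the revision segment dropped (repo and sha are made up):

uri = "hf://datasets/user/repo@0123456789abcdef0123456789abcdef01234567/data/train.lance/_versions/1.manifest"
assert _fix_hf_uri(uri) == "hf://datasets/user/repo/data/train.lance/_versions/1.manifest"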


def _fix_local_version_file(uri: str) -> str:
# replace symlinks with real files under _versions
if "/_versions/" in uri and is_local_path(uri):
path = Path(uri)
if path.is_symlink():
data = path.read_bytes()
path.unlink()
path.write_bytes(data)
return uri
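A small illustration with a temporary directory (POSIX paths assumed, since the check looks for "/_versions/"):

import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp()) / "train.lance" / "_versions"
root.mkdir(parents=True)
real = root / "1.manifest.orig"
real.write_bytes(b"manifest-bytes")
link = root / "1.manifest"
link.symlink_to(real)

_fix_local_version_file(str(link))
assert not link.is_symlink()                   # the symlink was replaced...
assert link.read_bytes() == b"manifest-bytes"  # ...by a regular file with the same content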


class Lance(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = LanceConfig
METADATA_EXTENSIONS = [".idx", ".txn", ".manifest"]

def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
import lance
import lance.file

if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
if self.repo_id:
api = HfApi(**dl_manager.download_config.storage_options["hf"])
dataset_sha = api.dataset_info(self.repo_id).sha
if dataset_sha != self.hash:
raise NotImplementedError(
f"lance doesn't support loading other revisions than 'main' yet, but got {self.hash}"
)
data_files = dl_manager.download(self.config.data_files)

# TODO: remove once Lance supports HF links with revisions
data_files = {split: [_fix_hf_uri(file) for file in files] for split, files in data_files.items()}
# TODO: remove once Lance supports symlinks for _version files
data_files = {split: [_fix_local_version_file(file) for file in files] for split, files in data_files.items()}

splits = []
for split_name, files in data_files.items():
storage_options = dl_manager.download_config.storage_options.get(files[0].split("://", 1)[0] + "://")

lance_dataset_uris = resolve_dataset_uris(files)
if lance_dataset_uris:
fragments = [
frag
for uri in lance_dataset_uris
for frag in lance.dataset(uri, storage_options=storage_options).get_fragments()
]
if self.info.features is None:
pa_schema = fragments[0]._ds.schema
splits.append(
datasets.SplitGenerator(
name=split_name,
gen_kwargs={"fragments": fragments, "lance_files": None},
)
)
else:
lance_files = [lance.file.LanceFileReader(file, storage_options=storage_options) for file in files]
if self.info.features is None:
pa_schema = lance_files[0].metadata().schema
splits.append(
datasets.SplitGenerator(
name=split_name,
gen_kwargs={"fragments": None, "lance_files": lance_files},
)
)
if self.info.features is None:
if self.config.columns:
fields = [
pa_schema.field(name) for name in self.config.columns if pa_schema.get_field_index(name) != -1
]
pa_schema = pa.schema(fields)
self.info.features = datasets.Features.from_arrow_schema(pa_schema)

return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
if self.info.features is not None:
# more expensive cast to support nested features with keys in a different order
# allows str <-> int/float or str to Audio for example
pa_table = table_cast(pa_table, self.info.features.arrow_schema)
return pa_table

def _generate_tables(
self,
fragments: Optional[List["lance.LanceFragment"]],
lance_files: Optional[List["lance.file.LanceFileReader"]],
):
if fragments:
for frag_idx, fragment in enumerate(fragments):
for batch_idx, batch in enumerate(
fragment.to_batches(columns=self.config.columns, batch_size=self.config.batch_size)
):
table = pa.Table.from_batches([batch])
yield Key(frag_idx, batch_idx), self._cast_table(table)
else:
for file_idx, lance_file in enumerate(lance_files):
for batch_idx, batch in enumerate(lance_file.read_all(batch_size=self.config.batch_size).to_batches()):
Review thread on this line:

Contributor Author (eddyxu): should we support columns pushdown here?

Member: Just added it at LanceFileReader() initialization, since the argument is not available in read_all().

Contributor Author (eddyxu, Jan 8, 2026): Actually, how does it work with multiple data files within the same fragment?

In Lance, one fragment can be 1 or more data files, where each data file covers a few columns. This is how we can add new features / columns cheaply without rewriting the dataset: by adding new data files to an existing fragment.

Maybe we can address it as a follow-up task.

Member: In that case it's a dataset, no? Since it requires a manifest or something to tell what the fragments are made of.

LanceFileReader() is only used for single files, i.e. files that don't belong to a lance dataset directory or require manifest files.

Contributor Author (eddyxu): I see. Let's 🚢 !!
table = pa.Table.from_batches([batch])
yield Key(file_idx, batch_idx), self._cast_table(table)
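The generator above mirrors the two shapes Lance data can take: fragments of a dataset directory, or standalone .lance files. A stripped-down sketch of the same two read paths using the pylance API directly (paths are hypothetical):

import lance
import lance.file
import pyarrow as pa

# 1) Lance dataset directory: iterate fragments, then record batches per fragment
ds = lance.dataset("/data/train.lance")
for fragment in ds.get_fragments():
    for batch in fragment.to_batches(columns=["id", "text"], batch_size=256):
        table = pa.Table.from_batches([batch])

# 2) Standalone Lance file (no dataset directory or manifest around it)
reader = lance.file.LanceFileReader("/data/standalone.lance")
for batch in reader.read_all(batch_size=256).to_batches():
    table = pa.Table.from_batches([batch])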
106 changes: 106 additions & 0 deletions tests/packaged_modules/test_lance.py
@@ -0,0 +1,106 @@
import lance
import numpy as np
import pyarrow as pa
import pytest

from datasets import load_dataset


@pytest.fixture
def lance_dataset(tmp_path) -> str:
data = pa.table(
{
"id": pa.array([1, 2, 3, 4]),
"value": pa.array([10.0, 20.0, 30.0, 40.0]),
"text": pa.array(["a", "b", "c", "d"]),
"vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 16, pa.float32()), list_size=4),
}
)
dataset_path = tmp_path / "test_dataset.lance"
lance.write_dataset(data, dataset_path)
return str(dataset_path)


@pytest.fixture
def lance_hf_dataset(tmp_path) -> str:
data = pa.table(
{
"id": pa.array([1, 2, 3, 4]),
"value": pa.array([10.0, 20.0, 30.0, 40.0]),
"text": pa.array(["a", "b", "c", "d"]),
"vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 16, pa.float32()), list_size=4),
}
)
dataset_dir = tmp_path / "data" / "train.lance"
dataset_dir.parent.mkdir(parents=True, exist_ok=True)
lance.write_dataset(data, dataset_dir)
lance.write_dataset(data[:2], tmp_path / "data" / "test.lance")

with open(tmp_path / "README.md", "w") as f:
f.write("""---
size_categories:
- 1M<n<10M
source_datasets:
- lance_test
---
# Test Lance Dataset\n\n
# My Markdown is fancier\n
""")

return str(tmp_path)


def test_load_lance_dataset(lance_dataset):
dataset_dict = load_dataset(lance_dataset)
assert "train" in dataset_dict.keys()

dataset = dataset_dict["train"]
assert "id" in dataset.column_names
assert "value" in dataset.column_names
assert "text" in dataset.column_names
assert "vector" in dataset.column_names
ids = dataset["id"]
assert ids == [1, 2, 3, 4]


@pytest.mark.parametrize("streaming", [False, True])
def test_load_hf_dataset(lance_hf_dataset, streaming):
dataset_dict = load_dataset(lance_hf_dataset, columns=["id", "text"], streaming=streaming)
assert "train" in dataset_dict.keys()
assert "test" in dataset_dict.keys()
dataset = dataset_dict["train"]

assert "id" in dataset.column_names
assert "text" in dataset.column_names
assert "value" not in dataset.column_names
assert "vector" not in dataset.column_names
ids = list(dataset["id"])
assert ids == [1, 2, 3, 4]
text = list(dataset["text"])
assert text == ["a", "b", "c", "d"]
assert "value" not in dataset.column_names


def test_load_vectors(lance_hf_dataset):
dataset_dict = load_dataset(lance_hf_dataset, columns=["vector"])
assert "train" in dataset_dict.keys()
dataset = dataset_dict["train"]

assert "vector" in dataset.column_names
vectors = dataset.data["vector"].combine_chunks().values.to_numpy(zero_copy_only=False)
assert np.allclose(vectors, np.full(16, 0.1))


@pytest.mark.parametrize("streaming", [False, True])
def test_load_lance_streaming_modes(lance_hf_dataset, streaming):
"""Test loading Lance dataset in both streaming and non-streaming modes."""
from datasets import IterableDataset

ds = load_dataset(lance_hf_dataset, split="train", streaming=streaming)
if streaming:
assert isinstance(ds, IterableDataset)
items = list(ds)
else:
items = list(ds)
assert len(items) == 4
assert all("id" in item for item in items)