Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@
)


def _normalize_path_or_paths(path_or_paths: Union[PathLike, Sequence_[PathLike]]) -> Union[PathLike, list[PathLike]]:
    """Return a collection of paths as a ``list``; pass anything else through untouched.

    Args:
        path_or_paths: Either a single path or a sequence of paths.

    Returns:
        A new ``list`` holding the same items when ``path_or_paths`` is a
        ``Sequence_`` instance other than ``str``, ``bytes`` or ``os.PathLike``;
        otherwise the argument itself, unchanged.
    """
    # str and bytes are sequences in their own right, so a single path has to be
    # ruled out first or it would be split into its individual elements.
    if isinstance(path_or_paths, (str, bytes, os.PathLike)):
        return path_or_paths
    if isinstance(path_or_paths, Sequence_):
        return list(path_or_paths)
    return path_or_paths


class DatasetInfoMixin:
"""This base class exposes some attributes of DatasetInfo
at the base level of the Dataset for easy access.
Expand Down Expand Up @@ -1276,7 +1282,7 @@ def from_list(

@staticmethod
def from_csv(
path_or_paths: Union[PathLike, list[PathLike]],
path_or_paths: Union[PathLike, Sequence_[PathLike]],
split: Optional[NamedSplit] = None,
features: Optional[Features] = None,
cache_dir: str = None,
Expand Down Expand Up @@ -1320,7 +1326,7 @@ def from_csv(
from .io.csv import CsvDatasetReader

return CsvDatasetReader(
path_or_paths,
_normalize_path_or_paths(path_or_paths),
split=split,
features=features,
cache_dir=cache_dir,
Expand Down Expand Up @@ -1416,7 +1422,7 @@ def from_generator(

@staticmethod
def from_json(
path_or_paths: Union[PathLike, list[PathLike]],
path_or_paths: Union[PathLike, Sequence_[PathLike]],
split: Optional[NamedSplit] = None,
features: Optional[Features] = None,
cache_dir: str = None,
Expand Down Expand Up @@ -1463,7 +1469,7 @@ def from_json(
from .io.json import JsonDatasetReader

return JsonDatasetReader(
path_or_paths,
_normalize_path_or_paths(path_or_paths),
split=split,
features=features,
cache_dir=cache_dir,
Expand All @@ -1475,7 +1481,7 @@ def from_json(

@staticmethod
def from_parquet(
path_or_paths: Union[PathLike, list[PathLike]],
path_or_paths: Union[PathLike, Sequence_[PathLike]],
split: Optional[NamedSplit] = None,
features: Optional[Features] = None,
cache_dir: str = None,
Expand Down Expand Up @@ -1557,7 +1563,7 @@ def from_parquet(
from .io.parquet import ParquetDatasetReader

return ParquetDatasetReader(
path_or_paths,
_normalize_path_or_paths(path_or_paths),
split=split,
features=features,
cache_dir=cache_dir,
Expand All @@ -1572,7 +1578,7 @@ def from_parquet(

@staticmethod
def from_text(
path_or_paths: Union[PathLike, list[PathLike]],
path_or_paths: Union[PathLike, Sequence_[PathLike]],
split: Optional[NamedSplit] = None,
features: Optional[Features] = None,
cache_dir: str = None,
Expand Down Expand Up @@ -1623,7 +1629,7 @@ def from_text(
from .io.text import TextDatasetReader

return TextDatasetReader(
path_or_paths,
_normalize_path_or_paths(path_or_paths),
split=split,
features=features,
cache_dir=cache_dir,
Expand Down
16 changes: 12 additions & 4 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3908,12 +3908,14 @@ def test_dataset_from_csv_split(split, csv_path, tmp_path):
assert dataset.split == split if split else "train"


@pytest.mark.parametrize("path_type", [str, list])
@pytest.mark.parametrize("path_type", [str, list, tuple])
def test_dataset_from_csv_path_type(path_type, csv_path, tmp_path):
if issubclass(path_type, str):
path = csv_path
elif issubclass(path_type, list):
path = [csv_path]
elif issubclass(path_type, tuple):
path = (csv_path,)
cache_dir = tmp_path / "cache"
expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
dataset = Dataset.from_csv(path, cache_dir=cache_dir)
Expand Down Expand Up @@ -3981,12 +3983,14 @@ def test_dataset_from_json_split(split, jsonl_path, tmp_path):
assert dataset.split == split if split else "train"


@pytest.mark.parametrize("path_type", [str, list])
@pytest.mark.parametrize("path_type", [str, list, tuple])
def test_dataset_from_json_path_type(path_type, jsonl_path, tmp_path):
if issubclass(path_type, str):
path = jsonl_path
elif issubclass(path_type, list):
path = [jsonl_path]
elif issubclass(path_type, tuple):
path = (jsonl_path,)
cache_dir = tmp_path / "cache"
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
dataset = Dataset.from_json(path, cache_dir=cache_dir)
Expand Down Expand Up @@ -4041,12 +4045,14 @@ def test_dataset_from_parquet_split(split, parquet_path, tmp_path):
assert dataset.split == split if split else "train"


@pytest.mark.parametrize("path_type", [str, list])
@pytest.mark.parametrize("path_type", [str, list, tuple])
def test_dataset_from_parquet_path_type(path_type, parquet_path, tmp_path):
if issubclass(path_type, str):
path = parquet_path
elif issubclass(path_type, list):
path = [parquet_path]
elif issubclass(path_type, tuple):
path = (parquet_path,)
cache_dir = tmp_path / "cache"
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
dataset = Dataset.from_parquet(path, cache_dir=cache_dir)
Expand Down Expand Up @@ -4100,12 +4106,14 @@ def test_dataset_from_text_split(split, text_path, tmp_path):
assert dataset.split == split if split else "train"


@pytest.mark.parametrize("path_type", [str, list])
@pytest.mark.parametrize("path_type", [str, list, tuple])
def test_dataset_from_text_path_type(path_type, text_path, tmp_path):
if issubclass(path_type, str):
path = text_path
elif issubclass(path_type, list):
path = [text_path]
elif issubclass(path_type, tuple):
path = (text_path,)
cache_dir = tmp_path / "cache"
expected_features = {"text": "string"}
dataset = Dataset.from_text(path, cache_dir=cache_dir)
Expand Down