Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,6 +1707,7 @@ def load_from_disk(
dataset_path: PathLike,
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
progress_bar: Optional[bool] = None,
) -> "Dataset":
"""
Loads a dataset that was previously saved using [`save_to_disk`] from a dataset directory, or from a
Expand Down Expand Up @@ -1791,15 +1792,19 @@ def load_from_disk(
)
keep_in_memory = keep_in_memory if keep_in_memory is not None else is_small_dataset(dataset_size)
table_cls = InMemoryTable if keep_in_memory else MemoryMappedTable

# By default, the progress bar is disabled when the dataset has 16 or fewer data files; otherwise `disable=None` is used, so the bar is shown only when a TTY is attached. The user can still override this default behavior by explicitly setting `progress_bar` to `True` or `False`.
if progress_bar is None:
disable = len(state["_data_files"]) <= 16 or None
else:
disable = not bool(progress_bar)
arrow_table = concat_tables(
thread_map(
table_cls.from_file,
[posixpath.join(dest_dataset_path, data_file["filename"]) for data_file in state["_data_files"]],
tqdm_class=hf_tqdm,
desc="Loading dataset from disk",
# set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached
disable=len(state["_data_files"]) <= 16 or None,
disable=disable,
)
)

Expand Down
2 changes: 2 additions & 0 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,7 @@ def load_from_disk(
dataset_dict_path: PathLike,
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
progress_bar: Optional[bool] = None,
) -> "DatasetDict":
"""
Load a dataset that was previously saved using [`save_to_disk`] from a filesystem using `fsspec.spec.AbstractFileSystem`.
Expand Down Expand Up @@ -1422,6 +1423,7 @@ def load_from_disk(
dataset_dict_split_path,
keep_in_memory=keep_in_memory,
storage_options=storage_options,
progress_bar=progress_bar,
)
return dataset_dict

Expand Down
13 changes: 10 additions & 3 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,7 +1520,10 @@ def load_dataset(


def load_from_disk(
dataset_path: PathLike, keep_in_memory: Optional[bool] = None, storage_options: Optional[dict] = None
dataset_path: PathLike,
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
progress_bar: Optional[bool] = None,
) -> Union[Dataset, DatasetDict]:
"""
Loads a dataset that was previously saved using [`~Dataset.save_to_disk`] from a dataset directory, or
Expand Down Expand Up @@ -1560,9 +1563,13 @@ def load_from_disk(
if fs.isfile(posixpath.join(dataset_path, config.DATASET_INFO_FILENAME)) and fs.isfile(
posixpath.join(dataset_path, config.DATASET_STATE_JSON_FILENAME)
):
return Dataset.load_from_disk(dataset_path, keep_in_memory=keep_in_memory, storage_options=storage_options)
return Dataset.load_from_disk(
dataset_path, keep_in_memory=keep_in_memory, storage_options=storage_options, progress_bar=progress_bar
)
elif fs.isfile(posixpath.join(dataset_path, config.DATASETDICT_JSON_FILENAME)):
return DatasetDict.load_from_disk(dataset_path, keep_in_memory=keep_in_memory, storage_options=storage_options)
return DatasetDict.load_from_disk(
dataset_path, keep_in_memory=keep_in_memory, storage_options=storage_options, progress_bar=progress_bar
)
else:
raise FileNotFoundError(
f"Directory {dataset_path} is neither a `Dataset` directory nor a `DatasetDict` directory."
Expand Down