Refactor Dataset.map to reuse cache files mapped with different num_proc #7434

Open

ringohoffman wants to merge 16 commits into main

Commits (16)
5341540
Refactor Dataset.map to reuse cache files mapped with different num_proc
ringohoffman Mar 4, 2025
bdc17c9
Only give reprocessing message when doing a partial remap
ringohoffman Mar 4, 2025
d7c63fd
Update logging message to account for if a cache file will be written…
ringohoffman Mar 4, 2025
0df4132
Refactor string_to_dict to return None if there is no match instead o…
ringohoffman Mar 4, 2025
7f50b98
Merge branch 'return-none-if-string_to_dict-no-match' into reuse-cach…
ringohoffman Mar 4, 2025
79dc83b
Simplify existing existing_cache_file_map with string_to_dict
ringohoffman Mar 4, 2025
bb7f9b5
Set initial value if there are already existing cache files
ringohoffman Mar 4, 2025
dafe4f2
Merge branch 'main' into return-none-if-string_to_dict-no-match
ringohoffman Mar 5, 2025
e2c1a5c
Merge branch 'return-none-if-string_to_dict-no-match' into reuse-cach…
ringohoffman Mar 5, 2025
c82cab4
Allow for source_url_fields to be None
ringohoffman Mar 7, 2025
28d82dc
Merge branch 'main' into return-none-if-string_to_dict-no-match
ringohoffman Mar 7, 2025
71b6d16
Merge branch 'return-none-if-string_to_dict-no-match' into reuse-cach…
ringohoffman Mar 9, 2025
8cc0186
Merge branch 'main' into reuse-cache-on-different-num_proc
ringohoffman Mar 12, 2025
637c160
Add unicode escape to handle parsing string_to_dict in Windows paths
ringohoffman Mar 12, 2025
25c0015
Merge branch 'main' into reuse-cache-on-different-num_proc
lhoestq Mar 14, 2025
583c28e
Remove glob_pattern_to_regex
ringohoffman Mar 14, 2025
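
A minimal sketch of the behavior this PR targets (not part of the diff; the dataset name and cache path are borrowed from the review discussion below, and the comments describe the intended behavior rather than a guaranteed API):

import datasets

dataset = datasets.load_dataset("ylecun/mnist")
cache_file_name = "./cache/train.map"

# First run writes one cache file per worker, e.g. train_00000_of_00010.map ... train_00009_of_00010.map
dataset["train"].map(lambda x: x, cache_file_name=cache_file_name, num_proc=10)

# Previously, calling map again with a different num_proc recomputed everything.
# With this refactor, the existing 10-shard cache is detected and reused instead.
dataset["train"].map(lambda x: x, cache_file_name=cache_file_name, num_proc=5)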
292 changes: 185 additions & 107 deletions src/datasets/arrow_dataset.py
@@ -19,6 +19,7 @@
import contextlib
import copy
import fnmatch
import glob
import inspect
import itertools
import json
@@ -27,12 +28,13 @@
import posixpath
import re
import shutil
import string
import sys
import tempfile
import time
import warnings
import weakref
from collections import Counter
from collections import Counter, defaultdict
from collections.abc import Mapping
from copy import deepcopy
from functools import partial, wraps
@@ -2964,6 +2966,11 @@ def map(
if num_proc is not None and num_proc <= 0:
raise ValueError("num_proc must be an integer > 0.")

string_formatter = string.Formatter()
fields = {field_name for _, field_name, _, _ in string_formatter.parse(suffix_template) if field_name}
if fields != {"rank", "num_proc"}:
raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}")

# If the array is empty we do nothing (but we make sure to handle an empty indices mapping and remove the requested columns anyway)
if len(self) == 0:
if self._indices is not None: # empty indices mapping
@@ -3045,7 +3052,7 @@ def map(
cache_file_name = self._get_cache_file_path(new_fingerprint)
dataset_kwargs["cache_file_name"] = cache_file_name

def load_processed_shard_from_cache(shard_kwargs):
def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
"""Load a processed shard from cache if it exists, otherwise throw an error."""
shard = shard_kwargs["shard"]
# Check if we've already cached this computation (indexed by a hash)
@@ -3056,64 +3063,98 @@ def load_processed_shard_from_cache(shard_kwargs):
return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split)
raise NonExistentDatasetError

num_shards = num_proc if num_proc is not None else 1
if batched and drop_last_batch:
pbar_total = len(self) // num_shards // batch_size * num_shards * batch_size
else:
pbar_total = len(self)
def pbar_total(num_shards: int, batch_size: Optional[int]) -> int:
total = len(self)
if len(existing_cache_files) < num_shards:
total -= len(existing_cache_files) * total // num_shards
Member:

Shouldn't the total be the same even if some shards have already been computed?

As a user I'd expect the progress bar to resume from where I was in this case.

Contributor Author (ringohoffman):

Instead of subtracting it from the total, we can use the initial parameter:

import os
import datasets

dataset = datasets.load_dataset("ylecun/mnist")
cache_file_name = "./cache/train.map"

dataset["train"].map(lambda x: x, cache_file_name=cache_file_name, num_proc=10)

os.remove("./cache/train_00001_of_00010.map")
os.remove("./cache/train_00002_of_00010.map")

dataset["train"].map(lambda x: x, cache_file_name=cache_file_name, num_proc=5)

Map (num_proc=5):  80%|████████  | 48000/60000 [00:00<?, ? examples/s]
Map (num_proc=5): 100%|██████████| 60000/60000 [00:00<00:00, 28528.31 examples/s]

See bb7f9b5
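
For illustration only (not part of the review thread): a minimal sketch of how tqdm's initial parameter produces this "resumed" progress bar. The PR uses the library's hf_tqdm wrapper; plain tqdm is used here, with the MNIST numbers from the example above (60,000 examples total, 48,000 already cached, two 6,000-example shards left to remap):

from tqdm import tqdm

# The bar opens at 80% (48,000/60,000) and only advances for the shards that still need mapping.
with tqdm(unit=" examples", total=60_000, initial=48_000, desc="Map (num_proc=5)") as pbar:
    for _ in range(2):      # the two shards that were deleted and must be remapped
        pbar.update(6_000)  # 80% -> 90% -> 100%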

if batched and drop_last_batch:
batch_size = batch_size or 1
return total // num_shards // batch_size * num_shards * batch_size
return total

def get_existing_cache_file_map(
cache_file_name: Optional[str],
) -> dict[int, list[str]]:
cache_files_by_num_proc: dict[int, list[str]] = defaultdict(list)
if cache_file_name is None:
return cache_files_by_num_proc
if os.path.exists(cache_file_name):
cache_files_by_num_proc[1] = [cache_file_name]

suffix_pattern_parts: list[str] = []
for literal_text, field_name, format_spec, _ in string_formatter.parse(suffix_template):
suffix_pattern_parts.append(re.escape(literal_text))
if field_name:
# TODO: we may want to place restrictions on acceptable format_spec or we will fail to match
# someone's hexadecimal or scientific notation format 😵
suffix_pattern_parts.append(f"(?P<{field_name}>\\d+)")
suffix_pattern = "".join(suffix_pattern_parts)

cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
if not cache_file_ext:
raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")

cache_file_pattern = "^" + re.escape(cache_file_prefix) + suffix_pattern + re.escape(cache_file_ext) + "$"
cache_file_regex = re.compile(cache_file_pattern)

for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"):
if m := cache_file_regex.match(cache_file):
file_num_proc = int(m.group("num_proc"))
cache_files_by_num_proc[file_num_proc].append(cache_file)

return cache_files_by_num_proc

existing_cache_file_map = get_existing_cache_file_map(cache_file_name)

num_shards = num_proc or 1
if existing_cache_file_map:
# to avoid remapping when a different num_proc is given than when originally cached, update num_shards to
# what was used originally

def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
num_shards_diff = abs(mapped_num_proc - num_shards)
return (
percent_missing, # choose the most complete set of existing cache files
num_shards_diff, # then choose the mapped_num_proc closest to the current num_proc
mapped_num_proc, # finally, choose whichever mapped_num_proc is lower
)

shards_done = 0
if num_proc is None or num_proc == 1:
transformed_dataset = None
try:
transformed_dataset = load_processed_shard_from_cache(dataset_kwargs)
logger.info(f"Loading cached processed dataset at {dataset_kwargs['cache_file_name']}")
except NonExistentDatasetError:
pass
if transformed_dataset is None:
with hf_tqdm(
unit=" examples",
total=pbar_total,
desc=desc or "Map",
) as pbar:
for rank, done, content in Dataset._map_single(**dataset_kwargs):
if done:
shards_done += 1
logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
transformed_dataset = content
else:
pbar.update(content)
assert transformed_dataset is not None, "Failed to retrieve the result from map"
# update fingerprint if the dataset changed
if transformed_dataset._fingerprint != self._fingerprint:
transformed_dataset._fingerprint = new_fingerprint
return transformed_dataset
else:
num_shards = min(existing_cache_file_map, key=select_existing_cache_files)

def format_cache_file_name(
cache_file_name: Optional[str],
rank: Union[int, Literal["*"]], # noqa: F722
) -> Optional[str]:
if not cache_file_name:
return cache_file_name
sep = cache_file_name.rindex(".")
base_name, extension = cache_file_name[:sep], cache_file_name[sep:]
if isinstance(rank, int):
cache_file_name = base_name + suffix_template.format(rank=rank, num_proc=num_proc) + extension
logger.info(f"Process #{rank} will write at {cache_file_name}")
else:
cache_file_name = (
base_name
+ suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_proc)
+ extension
)
existing_cache_files = existing_cache_file_map.get(num_shards, [])

def format_cache_file_name(
cache_file_name: Optional[str],
rank: Union[int, Literal["*"]], # noqa: F722
) -> Optional[str]:
if not cache_file_name:
return cache_file_name

def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_proc)
validate_fingerprint(new_fingerprint)
return new_fingerprint
cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
if not cache_file_ext:
raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")

if isinstance(rank, int):
cache_file_name = (
cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext
)
logger.info(f"Process #{rank} will write at {cache_file_name}")
else:
# TODO: this assumes the format_spec of rank in suffix_template
cache_file_name = (
cache_file_prefix
+ suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_shards)
+ cache_file_ext
)
return cache_file_name

def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_shards)
validate_fingerprint(new_fingerprint)
return new_fingerprint

if num_proc is not None and num_proc > 1:
prev_env = deepcopy(os.environ)
# check if parallelism is off
# from https://github.com/huggingface/tokenizers/blob/bb668bc439dc34389b71dbb8ce0c597f15707b53/tokenizers/src/utils/parallelism.rs#L22
@@ -3128,9 +3169,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
):
logger.warning("Setting TOKENIZERS_PARALLELISM=false for forked processes.")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
else:
prev_env = os.environ

kwargs_per_job: list[Optional[dict[str, Any]]]
if num_shards == 1:
shards = [self]
kwargs_per_job = [dataset_kwargs]
else:
shards = [
self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
for rank in range(num_proc)
self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
for rank in range(num_shards)
]
kwargs_per_job = [
{
@@ -3144,60 +3193,89 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
for rank in range(num_shards)
]

transformed_shards = [None] * num_shards
for rank in range(num_shards):
try:
transformed_shards[rank] = load_processed_shard_from_cache(kwargs_per_job[rank])
kwargs_per_job[rank] = None
except NonExistentDatasetError:
pass

kwargs_per_job = [kwargs for kwargs in kwargs_per_job if kwargs is not None]

# We try to create a pool with as many workers as dataset not yet cached.
if kwargs_per_job:
if len(kwargs_per_job) < num_shards:
logger.info(
f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
)
with Pool(len(kwargs_per_job)) as pool:
os.environ = prev_env
logger.info(f"Spawning {num_proc} processes")
with hf_tqdm(
unit=" examples",
total=pbar_total,
desc=(desc or "Map") + f" (num_proc={num_proc})",
) as pbar:
transformed_shards: list[Optional[Dataset]] = [None] * num_shards
for rank in range(num_shards):
try:
job_kwargs = kwargs_per_job[rank]
assert job_kwargs is not None
transformed_shards[rank] = load_processed_shard_from_cache(job_kwargs)
kwargs_per_job[rank] = None
except NonExistentDatasetError:
pass

if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]:
if len(unprocessed_kwargs_per_job) < num_shards:
logger.info(
f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were "
" missing from the cache."
)

with hf_tqdm(
unit=" examples",
total=pbar_total(num_shards, batch_size),
desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""),
) as pbar:
shards_done = 0

def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, int]) -> None:
nonlocal shards_done
if done:
shards_done += 1
logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
assert isinstance(content, Dataset)
transformed_shards[rank or 0] = content
else:
assert isinstance(content, int)
pbar.update(content)

if num_proc is not None and num_proc > 1:
with Pool(num_proc) as pool:
os.environ = prev_env
logger.info(f"Spawning {num_proc} processes")

for rank, done, content in iflatmap_unordered(
pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
pool, Dataset._map_single, kwargs_iterable=unprocessed_kwargs_per_job
):
if done:
shards_done += 1
logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
transformed_shards[rank] = content
else:
pbar.update(content)
pool.close()
pool.join()
# Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
for kwargs in kwargs_per_job:
del kwargs["shard"]
else:
logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
assert None not in transformed_shards, (
f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results"
check_if_shard_done(rank, done, content)

pool.close()
pool.join()
else:
for unprocessed_kwargs in unprocessed_kwargs_per_job:
for rank, done, content in Dataset._map_single(**unprocessed_kwargs):
check_if_shard_done(rank, done, content)

# Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
for job_kwargs in unprocessed_kwargs_per_job:
if "shard" in job_kwargs:
del job_kwargs["shard"]
else:
logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")

all_transformed_shards = [shard for shard in transformed_shards if shard is not None]
if len(transformed_shards) != len(all_transformed_shards):
raise ValueError(
f"Failed to retrieve results from map: result list {transformed_shards} still contains None - "
"at least one worker failed to return its results"
)
logger.info(f"Concatenating {num_proc} shards")
result = _concatenate_map_style_datasets(transformed_shards)
# update fingerprint if the dataset changed

if num_shards == 1:
result = all_transformed_shards[0]
else:
logger.info(f"Concatenating {num_shards} shards")
result = _concatenate_map_style_datasets(all_transformed_shards)

# update fingerprint if the dataset changed
result._fingerprint = (
new_fingerprint
if any(
transformed_shard._fingerprint != shard._fingerprint
for transformed_shard, shard in zip(transformed_shards, shards)
):
result._fingerprint = new_fingerprint
else:
result._fingerprint = self._fingerprint
return result
for transformed_shard, shard in zip(all_transformed_shards, shards)
)
else self._fingerprint
)

return result

@staticmethod
def _map_single(
@@ -3219,7 +3297,7 @@ def _map_single(
new_fingerprint: Optional[str] = None,
rank: Optional[int] = None,
offset: int = 0,
) -> Iterable[Tuple[int, bool, Union[int, "Dataset"]]]:
) -> Iterable[Tuple[Optional[int], bool, Union[int, "Dataset"]]]:
"""Apply a function to all the elements in the table (individually or in batches)
and update the table (if function does update examples).
