
Commit f7c04b5

synchronized split_dataset_by_node
1 parent 91f96a0 commit f7c04b5

3 files changed: +157 -32 lines changed

src/datasets/arrow_dataset.py

Lines changed: 17 additions & 3 deletions
@@ -6692,9 +6692,15 @@ def iter_random_indices():
         return concatenated_datasets.select(indices, **kwargs)
 
 
-def _split_by_node_map_style_dataset(dataset: Dataset, rank: int, world_size: int) -> Dataset:
+def _split_by_node_map_style_dataset(
+    dataset: Dataset,
+    rank: int,
+    world_size: int,
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+) -> Dataset:
     """
-    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
+    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`, with each
+    rank having the same number of examples thanks to the `stopping_strategy`.
     Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
     To maximize data loading throughput, chunks are made of contiguous data on disk if possible.
 
@@ -6709,7 +6715,15 @@ def _split_by_node_map_style_dataset(dataset: Dataset, rank: int, world_size: int) -> Dataset:
     Returns:
         [`Dataset`]: The dataset to be used on the node at rank `rank`.
     """
-    return dataset.shard(num_shards=world_size, index=rank, contiguous=True)
+    shard = dataset.shard(num_shards=world_size, index=rank, contiguous=True)
+    # Make sure all the shards have the same number of examples:
+    # - first_exhausted: len() = len(dataset) // world_size
+    # - all_exhausted: len() = len(dataset) // world_size + 1
+    if len(shard) == len(dataset) // world_size + 1 and stopping_strategy == "first_exhausted":
+        shard = shard.select(range(len(dataset) // world_size))
+    if len(shard) == len(dataset) // world_size and stopping_strategy == "all_exhausted":
+        shard = _concatenate_map_style_datasets([shard, shard.select([0])])
+    return shard
 
 
 # This is outside Dataset.filter as it needs to be picklable for multiprocessing
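The map-style branch needs no communication between ranks: each rank can tell locally whether its contiguous shard is one example longer than the others (trim it for "first_exhausted") or one example shorter (reuse one example for "all_exhausted"). A minimal standalone sketch with plain Python lists, not the `datasets` API; the helper name and slicing are illustrative, though the contiguous-split arithmetic mirrors what `Dataset.shard(contiguous=True)` does:

```python
# Standalone sketch of the equalization done by _split_by_node_map_style_dataset.
def split_by_node(examples, rank, world_size, stopping_strategy="first_exhausted"):
    base = len(examples) // world_size
    extra = len(examples) % world_size
    # contiguous split: the first `extra` ranks get one extra example
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    shard = examples[start:end]
    if stopping_strategy == "first_exhausted" and len(shard) == base + 1:
        shard = shard[:base]          # drop the extra example
    if stopping_strategy == "all_exhausted" and len(shard) == base:
        shard = shard + shard[:1]     # reuse one example to reach base + 1
    return shard

data = list(range(10))
print([len(split_by_node(data, r, 3, "first_exhausted")) for r in range(3)])  # [3, 3, 3]
print([len(split_by_node(data, r, 3, "all_exhausted")) for r in range(3)])    # [4, 4, 4]
```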

src/datasets/distributed.py

Lines changed: 27 additions & 8 deletions
@@ -1,4 +1,4 @@
-from typing import TypeVar
+from typing import Literal, TypeVar
 
 from .arrow_dataset import Dataset, _split_by_node_map_style_dataset
 from .iterable_dataset import IterableDataset, _split_by_node_iterable_dataset
@@ -7,20 +7,35 @@
 DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)
 
 
-def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
+def split_dataset_by_node(
+    dataset: DatasetType,
+    rank: int,
+    world_size: int,
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+) -> DatasetType:
     """
-    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
+    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`, with each
+    rank having the same number of examples thanks to the `stopping_strategy`.
+
+    The stopping strategy allows each node to have the same number of examples:
+
+    - "first_exhausted": stop when the first node runs out of data, and discard the extra data in the other nodes
+    - "all_exhausted": stop when the last node runs out of data, and the other nodes may reuse their data to compensate
 
     For map-style datasets:
 
     Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
     To maximize data loading throughput, chunks are made of contiguous data on disk if possible.
+    This doesn't need communication between nodes, since each node knows how many examples
+    are available and can discard or reuse up to one example accordingly.
 
     For iterable datasets:
 
-    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
-    then the shards are evenly assigned across the nodes, which is the most optimized.
-    Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
+    The shards are evenly assigned across the nodes.
+    To maximize data loading throughput, each node has its own data and there is no overlap between nodes.
+    The stopping strategy has less impact at the end of training if the dataset has a number of shards that is
+    a multiple of `world_size` (e.g. if `dataset.num_shards % world_size == 0`), since each node has roughly
+    the same amount of data available. Nodes communicate using torch distributed to decide when to stop.
 
     Args:
         dataset ([`Dataset`] or [`IterableDataset`]):
@@ -34,6 +49,10 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
         [`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
     """
     if isinstance(dataset, Dataset):
-        return _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size)
+        return _split_by_node_map_style_dataset(
+            dataset, rank=rank, world_size=world_size, stopping_strategy=stopping_strategy
+        )
     else:
-        return _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size)
+        return _split_by_node_iterable_dataset(
+            dataset, rank=rank, world_size=world_size, stopping_strategy=stopping_strategy
+        )
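For reference, the updated public entry point can be used like this; the dataset name is only an example, and the comments restate the behavior described in the docstring above:

```python
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

ds = load_dataset("squad", split="train")  # map-style Dataset, used here as an example

# Every rank gets exactly len(ds) // world_size examples (extra examples are dropped).
per_rank = split_dataset_by_node(ds, rank=0, world_size=8, stopping_strategy="first_exhausted")

# Every rank gets len(ds) // world_size + 1 examples (short ranks reuse one example).
per_rank_over = split_dataset_by_node(ds, rank=0, world_size=8, stopping_strategy="all_exhausted")

# For an IterableDataset (e.g. load_dataset(..., streaming=True)), the same call applies,
# and the ranks coordinate over torch.distributed at iteration time to stop together.
```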

src/datasets/iterable_dataset.py

Lines changed: 113 additions & 21 deletions
@@ -2110,6 +2110,95 @@ def num_shards(self) -> int:
         return self.ex_iterable.num_shards
 
 
+class SyncedDistributedExamplesIterable(_BaseExamplesIterable):
+    def __init__(
+        self,
+        ex_iterable: _BaseExamplesIterable,
+        rank: int,
+        world_size: int,
+        stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+    ):
+        super().__init__()
+        self.ex_iterable = ex_iterable
+        self.rank = rank
+        self.world_size = world_size
+        self.stopping_strategy = stopping_strategy
+        # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
+        # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
+        import torch
+
+        self.bool_strategy_func = torch.all if stopping_strategy == "all_exhausted" else torch.any
+
+    @property
+    def iter_arrow(self):
+        if self.ex_iterable.iter_arrow:
+            return self._iter_arrow
+
+    @property
+    def is_typed(self):
+        return self.ex_iterable.is_typed
+
+    @property
+    def features(self):
+        return self.ex_iterable.features
+
+    def _init_state_dict(self) -> dict:
+        self._state_dict = self.ex_iterable._init_state_dict()
+        return self._state_dict
+
+    def __iter__(self):
+        import torch
+        import torch.distributed as dist
+
+        is_exhausted = torch.zeros(self.world_size, dtype=torch.bool)
+        while True:
+            for key, example in self.ex_iterable:
+                yield key, example
+                dist.all_reduce(is_exhausted)
+                if self.bool_strategy_func(is_exhausted):
+                    return
+            is_exhausted[self.rank] = True
+            if self._state_dict is not None:
+                self._state_dict = self.ex_iterable._init_state_dict()
+
+    def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
+        import torch
+        import torch.distributed as dist
+
+        is_exhausted = torch.zeros(self.world_size, dtype=torch.bool)
+        while True:
+            for key, pa_table in self.ex_iterable._iter_arrow():
+                yield key, pa_table
+                dist.all_reduce(is_exhausted)
+                if self.bool_strategy_func(is_exhausted):
+                    return
+            is_exhausted[self.rank] = True
+            if self._state_dict is not None:
+                self._state_dict = self.ex_iterable._init_state_dict()
+
+    def shuffle_data_sources(self, generator: np.random.Generator) -> "SyncedDistributedExamplesIterable":
+        """Shuffle the wrapped examples iterable."""
+        return SyncedDistributedExamplesIterable(
+            self.ex_iterable.shuffle_data_sources(generator),
+            rank=self.rank,
+            world_size=self.world_size,
+            stopping_strategy=self.stopping_strategy,
+        )
+
+    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SyncedDistributedExamplesIterable":
+        """Keep only the requested shard."""
+        return SyncedDistributedExamplesIterable(
+            self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
+            rank=self.rank,
+            world_size=self.world_size,
+            stopping_strategy=self.stopping_strategy,
+        )
+
+    @property
+    def num_shards(self) -> int:
+        return self.ex_iterable.num_shards
+
+
 @dataclass
 class ShufflingConfig:
     generator: np.random.Generator
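A condensed sketch of the stop protocol that `SyncedDistributedExamplesIterable` implements, stripped of the `datasets` plumbing. It assumes a `torch.distributed` process group is already initialized (e.g. under `torchrun`); the function name and arguments are illustrative:

```python
import torch
import torch.distributed as dist


def synced_iter(local_examples, rank, world_size, stopping_strategy="first_exhausted"):
    # "first_exhausted": stop when any rank has run out; "all_exhausted": when every rank has.
    done = torch.all if stopping_strategy == "all_exhausted" else torch.any
    is_exhausted = torch.zeros(world_size, dtype=torch.bool)
    while True:  # keep re-reading local data until the collective says to stop
        for example in local_examples:
            yield example
            # every rank hits this collective once per yielded example, so all ranks
            # see the same flags and stop after the same number of yields
            dist.all_reduce(is_exhausted)  # the default SUM reduction acts as OR on bool flags
            if done(is_exhausted):
                return
        # local data ran out: raise this rank's flag, then loop to reuse the local data
        is_exhausted[rank] = True
```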
@@ -2120,6 +2209,7 @@ class ShufflingConfig:
 class DistributedConfig:
     rank: int
     world_size: int
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"]
 
 
 def _maybe_add_torch_iterable_dataset_parent_class(cls):
@@ -2466,6 +2556,7 @@ def _prepare_ex_iterable_for_iteration(
             self._formatting
             and (ex_iterable.iter_arrow or self._formatting.is_table)
             or (self.features and ex_iterable.features != self.features)
+            or self._distributed
         ):
             ex_iterable = RebatchedArrowExamplesIterable(
                 ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch
@@ -2478,25 +2569,21 @@ def _prepare_ex_iterable_for_iteration(
         if self._distributed:
             rank = self._distributed.rank
             world_size = self._distributed.world_size
-            if ex_iterable.num_shards % world_size == 0:
-                if self._is_main_process():
-                    num_shards_per_node = ex_iterable.num_shards // world_size
-                    plural = "s" if num_shards_per_node > 1 else ""
-                    logger.info(
-                        f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
-                    )
-                ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
-            else:
-                if self._is_main_process():
-                    logger.info(
-                        f"Assigning 1 out of {world_size} examples of the dataset to each node. The others are skipped during the iteration."
-                    )
-                    logger.info(
-                        f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
-                        f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
-                        f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
-                    )
-                ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
+            if self._is_main_process():
+                num_shards_per_node = ex_iterable.num_shards // world_size
+                if ex_iterable.num_shards % world_size == 0:
+                    num_shards_per_node = f"{num_shards_per_node}-{num_shards_per_node + 1}"
+                plural = "s" if str(num_shards_per_node) != "1" else ""
+                logger.info(
+                    f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
+                )
+            ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
+            ex_iterable = SyncedDistributedExamplesIterable(
+                ex_iterable,
+                rank=self._distributed.rank,
+                world_size=self._distributed.world_size,
+                stopping_strategy=self._distributed.stopping_strategy,
+            )
 
         if self._formatting or (self.features and ex_iterable.features != self.features):
             ex_iterable = FormattedExamplesIterable(
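With the old fallback to `StepExamplesIterable` removed, shards (data sources) are always distributed across ranks, even when `num_shards % world_size != 0`; the synced iterable then equalizes the number of yielded examples. Purely for illustration (the actual assignment lives in `shard_data_sources` and is not part of this diff), an assumed round-robin layout of 10 shards over 3 ranks looks like this:

```python
# Hypothetical round-robin layout: rank r keeps shards r, r + world_size, r + 2 * world_size, ...
def shards_for_rank(num_shards: int, rank: int, world_size: int) -> list[int]:
    return list(range(rank, num_shards, world_size))

for rank in range(3):
    print(rank, shards_for_rank(10, rank, 3))
# 0 [0, 3, 6, 9]  -> 4 shards
# 1 [1, 4, 7]     -> 3 shards
# 2 [2, 5, 8]     -> 3 shards: rank 0 has more data, hence the need for a synchronized stop
```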
@@ -4662,7 +4749,12 @@ def _interleave_iterable_datasets(
     return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
 
 
-def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset:
+def _split_by_node_iterable_dataset(
+    dataset: IterableDataset,
+    rank: int,
+    world_size: int,
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+) -> IterableDataset:
     """
     Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
 
@@ -4684,7 +4776,7 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset:
     if dataset._distributed:
         rank = world_size * dataset._distributed.rank + rank
         world_size = world_size * dataset._distributed.world_size
-    distributed = DistributedConfig(rank=rank, world_size=world_size)
+    distributed = DistributedConfig(rank=rank, world_size=world_size, stopping_strategy=stopping_strategy)
     return IterableDataset(
         ex_iterable=dataset._ex_iterable,
         info=dataset._info.copy(),
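The `rank`/`world_size` arithmetic above composes nested splits into a single flat split, and the new `stopping_strategy` is simply carried along in the `DistributedConfig`. A quick worked example of the composition (the numbers are illustrative):

```python
# A dataset already split across 2 nodes (this rank is node 1) is split again
# across 4 dataloader workers (this worker is worker 3):
outer_rank, outer_world_size = 1, 2   # stored in dataset._distributed
inner_rank, inner_world_size = 3, 4   # arguments of the second split call

flat_rank = inner_world_size * outer_rank + inner_rank   # 4 * 1 + 3 = 7
flat_world_size = inner_world_size * outer_world_size    # 4 * 2 = 8
print(flat_rank, flat_world_size)  # 7 8 -> equivalent to one split over 8 ranks
```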
