@@ -2110,6 +2110,95 @@ def num_shards(self) -> int:
         return self.ex_iterable.num_shards
 
 
+class SyncedDistributedExamplesIterable(_BaseExamplesIterable):
+    def __init__(
+        self,
+        ex_iterable: _BaseExamplesIterable,
+        rank: int,
+        world_size: int,
+        stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+    ):
+        super().__init__()
+        self.ex_iterable = ex_iterable
+        self.rank = rank
+        self.world_size = world_size
+        self.stopping_strategy = stopping_strategy
+        # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
+        # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
+        import torch
+
+        self.bool_strategy_func = torch.all if stopping_strategy == "all_exhausted" else torch.any
+
+    @property
+    def iter_arrow(self):
+        if self.ex_iterable.iter_arrow:
+            return self._iter_arrow
+
+    @property
+    def is_typed(self):
+        return self.ex_iterable.is_typed
+
+    @property
+    def features(self):
+        return self.ex_iterable.features
+
+    def _init_state_dict(self) -> dict:
+        self._state_dict = self.ex_iterable._init_state_dict()
+        return self._state_dict
+
+    def __iter__(self):
+        import torch
+        import torch.distributed as dist
+
+        is_exhausted = torch.zeros(self.world_size, dtype=torch.bool)
+        while True:
+            for key, example in self.ex_iterable:
+                yield key, example
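+            # the local shard is exhausted: exchange exhaustion flags with the other ranks,
+            # stop if the stopping strategy is satisfied, otherwise mark this rank as exhausted
+            # and restart its shard from a fresh state so it keeps yielding examples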
+            dist.all_reduce(is_exhausted)
+            if self.bool_strategy_func(is_exhausted):
+                return
+            is_exhausted[self.rank] = True
+            if self._state_dict is not None:
+                self._state_dict = self.ex_iterable._init_state_dict()
+
+    def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
+        import torch
+        import torch.distributed as dist
+
+        is_exhausted = torch.zeros(self.world_size, dtype=torch.bool)
+        while True:
+            for key, pa_table in self.ex_iterable._iter_arrow():
+                yield key, pa_table
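+            # same synchronization as __iter__, but yielding Arrow tables instead of examples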
+            dist.all_reduce(is_exhausted)
+            if self.bool_strategy_func(is_exhausted):
+                return
+            is_exhausted[self.rank] = True
+            if self._state_dict is not None:
+                self._state_dict = self.ex_iterable._init_state_dict()
+
+    def shuffle_data_sources(self, generator: np.random.Generator) -> "SyncedDistributedExamplesIterable":
+        """Shuffle the wrapped examples iterable."""
+        return SyncedDistributedExamplesIterable(
+            self.ex_iterable.shuffle_data_sources(generator),
+            rank=self.rank,
+            world_size=self.world_size,
+            stopping_strategy=self.stopping_strategy,
+        )
+
+    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SyncedDistributedExamplesIterable":
+        """Keep only the requested shard."""
+        return SyncedDistributedExamplesIterable(
+            self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
+            rank=self.rank,
+            world_size=self.world_size,
+            stopping_strategy=self.stopping_strategy,
+        )
+
+    @property
+    def num_shards(self) -> int:
+        return self.ex_iterable.num_shards
+
+
 @dataclass
 class ShufflingConfig:
     generator: np.random.Generator
@@ -2120,6 +2209,7 @@ class ShufflingConfig:
 class DistributedConfig:
     rank: int
     world_size: int
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"]
 
 
 def _maybe_add_torch_iterable_dataset_parent_class(cls):
@@ -2466,6 +2556,7 @@ def _prepare_ex_iterable_for_iteration(
             self._formatting
             and (ex_iterable.iter_arrow or self._formatting.is_table)
             or (self.features and ex_iterable.features != self.features)
+            or self._distributed
         ):
             ex_iterable = RebatchedArrowExamplesIterable(
                 ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch
@@ -2478,25 +2569,21 @@ def _prepare_ex_iterable_for_iteration(
         if self._distributed:
             rank = self._distributed.rank
             world_size = self._distributed.world_size
-            if ex_iterable.num_shards % world_size == 0:
-                if self._is_main_process():
-                    num_shards_per_node = ex_iterable.num_shards // world_size
-                    plural = "s" if num_shards_per_node > 1 else ""
-                    logger.info(
-                        f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
-                    )
-                ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
-            else:
-                if self._is_main_process():
-                    logger.info(
-                        f"Assigning 1 out of {world_size} examples of the dataset to each node. The others are skipped during the iteration."
-                    )
-                    logger.info(
-                        f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
-                        f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
-                        f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
-                    )
-                ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
+            if self._is_main_process():
+                num_shards_per_node = ex_iterable.num_shards // world_size
+                if ex_iterable.num_shards % world_size != 0:
+                    num_shards_per_node = f"{num_shards_per_node}-{num_shards_per_node + 1}"
+                plural = "s" if str(num_shards_per_node) != "1" else ""
+                logger.info(
+                    f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
+                )
+            ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
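+            # wrap the per-node shard so that all ranks exchange exhaustion flags while iterating
+            # and stop together according to the configured stopping strategy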
+            ex_iterable = SyncedDistributedExamplesIterable(
+                ex_iterable,
+                rank=self._distributed.rank,
+                world_size=self._distributed.world_size,
+                stopping_strategy=self._distributed.stopping_strategy,
+            )
 
         if self._formatting or (self.features and ex_iterable.features != self.features):
             ex_iterable = FormattedExamplesIterable(
@@ -4662,7 +4749,12 @@ def _interleave_iterable_datasets(
     return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
 
 
-def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset:
+def _split_by_node_iterable_dataset(
+    dataset: IterableDataset,
+    rank: int,
+    world_size: int,
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+) -> IterableDataset:
     """
     Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
 
@@ -4684,7 +4776,7 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
     if dataset._distributed:
         rank = world_size * dataset._distributed.rank + rank
         world_size = world_size * dataset._distributed.world_size
-    distributed = DistributedConfig(rank=rank, world_size=world_size)
+    distributed = DistributedConfig(rank=rank, world_size=world_size, stopping_strategy=stopping_strategy)
     return IterableDataset(
         ex_iterable=dataset._ex_iterable,
         info=dataset._info.copy(),
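
Usage sketch (not part of the diff): the snippet below shows how the new stopping_strategy could be exercised end to end. It assumes that the public split_dataset_by_node helper in datasets.distributed forwards the new argument to _split_by_node_iterable_dataset; that wrapper is not shown in this hunk, and the dataset name is only an example.

import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

dist.init_process_group("gloo")

# stream a dataset and give each rank its own subset of shards
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
ds = split_dataset_by_node(
    ds,
    rank=dist.get_rank(),
    world_size=dist.get_world_size(),
    stopping_strategy="all_exhausted",  # assumed to be forwarded down to DistributedConfig
)

for example in ds:
    ...  # training step; ranks stop together according to the stopping strategy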