Commit 1f5f0ec

Keep the training data continuous and the total batch size constant regardless of changes in the replica world size.
1 parent abcff8a commit 1f5f0ec
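
The intent, as a minimal sketch (the helper and the names `global_batch_size` and `samples_consumed` are illustrative, not part of this commit): when the replica world size changes, rebuild the dataloader so that samples already consumed are skipped and the fixed global batch is re-split across the new number of replicas.

from torch.utils.data import DataLoader

from torchft.data import SkipDistributedSampler  # added by this commit


def rebuild_dataloader(dataset, world_size, rank, global_batch_size, samples_consumed):
    # Continuity: resume right after the samples every replica has already
    # consumed, instead of restarting or replaying the epoch.
    sampler = SkipDistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        skip_samples=samples_consumed,
    )
    # Constant total batch size: re-derive each replica's share from the new
    # world size (this sketch assumes world_size divides global_batch_size).
    per_replica_batch = global_batch_size // world_size
    return DataLoader(dataset, sampler=sampler, batch_size=per_replica_batch)

Calling such a helper after every membership change keeps the remaining indices split once across the surviving replicas (modulo the sampler's padding when the remainder is uneven).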

File tree: 7 files changed (+1548, -11 lines)


torchft/data.py

Lines changed: 260 additions & 1 deletion
@@ -14,10 +14,269 @@
 dataloader frequently to avoid duplicate batches.
 """
 
-from typing import Optional
+import math
+from collections.abc import Iterator
+from typing import Iterable, Optional, TypeVar, Union
 
+import torch
 import torch.distributed as dist
 from torch.utils import data
+from torch.utils.data.dataset import Dataset
+from torch.utils.data.sampler import BatchSampler, Sampler
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+class SkipDistributedSampler(Sampler[_T_co]):
+    def __init__(
+        self,
+        dataset: Dataset,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+        skip_samples: int = 0,
+    ) -> None:
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        if rank >= num_replicas or rank < 0:
+            raise ValueError(
+                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
+            )
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        self.skip_samples = skip_samples
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.skip_samples - self.num_replicas)
+                / self.num_replicas  # type: ignore[arg-type]
+            )
+        else:
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.skip_samples) / self.num_replicas
+            )  # type: ignore[arg-type]
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self) -> Iterator[_T_co]:
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
+        else:
+            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
+
+        if not self.drop_last:
+            indices = indices[self.skip_samples : len(indices)]
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[
+                    :padding_size
+                ]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[self.skip_samples : self.skip_samples + self.total_size]
+        if len(indices) != self.total_size:
+            raise AssertionError(
+                f"Number of indices ({len(indices)}) does not match total_size ({self.total_size})"
+            )
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        if len(indices) != self.num_samples:
+            raise AssertionError(
+                f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})"
+            )
+
+        # pyrefly: ignore # bad-return
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        r"""
+        Set the epoch for this sampler.
+
+        When :attr:`shuffle=True`, this ensures all replicas
+        use a different random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
+
+
+class DistributedBatchSampler(Sampler[list[int]]):
+    r"""Wraps a BatchSampler to distribute batches across multiple processes in distributed training.
+
+    Each process gets a subset of batches based on its rank and the total number of replicas.
+    This is useful for distributed training where each process should work on different batches
+    to avoid data duplication.
+
+    Args:
+        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``
+        num_replicas (int): Number of processes participating in distributed training.
+        rank (int): Rank of the current process within num_replicas.
+            Should be in range [0, num_replicas - 1].
+        even_batches (bool): If ``True``, ensures all ranks get exactly the same number
+            of batches by potentially dropping some batches. If ``False``, some ranks
+            may get one extra batch. Default: ``True``.
+
+    Example:
+        >>> # For a dataset with indices 0-20, batch_size=2, num_replicas=2
+        >>> # All batches would be: [[0,1], [2,3], [4,5], [6,7], [8,9], [10,11], ...]
+        >>>
+        >>> # With even_batches=False (original behavior):
+        >>> # rank=0 gets batches: [[0,1], [4,5], [8,9], [12,13], [16,17], [20]] (6 batches)
+        >>> # rank=1 gets batches: [[2,3], [6,7], [10,11], [14,15], [18,19]] (5 batches)
+        >>> sampler_rank0 = DistributedBatchSampler(
+        ...     SequentialSampler(range(21)), batch_size=2, drop_last=False,
+        ...     num_replicas=2, rank=0, even_batches=False
+        ... )
+        >>> list(sampler_rank0)
+        [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17], [20]]
+        >>>
+        >>> # With even_batches=True (default behavior):
+        >>> # Both ranks get exactly 5 batches (drops the last batch [20])
+        >>> # rank=0 gets batches: [[0,1], [4,5], [8,9], [12,13], [16,17]] (5 batches)
+        >>> # rank=1 gets batches: [[2,3], [6,7], [10,11], [14,15], [18,19]] (5 batches)
+        >>> sampler_rank0_even = DistributedBatchSampler(
+        ...     SequentialSampler(range(21)), batch_size=2, drop_last=False,
+        ...     num_replicas=2, rank=0, even_batches=True
+        ... )
+        >>> list(sampler_rank0_even)
+        [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17]]
+    """
+
+    def __init__(
+        self,
+        sampler: Union[Sampler[int], Iterable[int]],
+        batch_size: int,
+        drop_last: bool,
+        num_replicas: int = 1,
+        rank: int = 0,
+        even_batches: bool = True,
+    ) -> None:
+        # Validate batch_size
+        if (
+            not isinstance(batch_size, int)
+            or isinstance(batch_size, bool)
+            or batch_size <= 0
+        ):
+            raise ValueError(
+                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
+            )
+
+        # Validate drop_last
+        if not isinstance(drop_last, bool):
+            raise ValueError(
+                f"drop_last should be a boolean value, but got drop_last={drop_last}"
+            )
+
+        # Validate num_replicas
+        if not isinstance(num_replicas, int) or num_replicas <= 0:
+            raise ValueError(
+                f"num_replicas should be a positive integer value, but got num_replicas={num_replicas}"
+            )
+
+        # Validate rank
+        if not isinstance(rank, int) or rank < 0 or rank >= num_replicas:
+            raise ValueError(
+                f"rank should be an integer in range [0, {num_replicas - 1}], but got rank={rank}"
+            )
+
+        # Validate even_batches
+        if not isinstance(even_batches, bool):
+            raise ValueError(
+                f"even_batches should be a boolean value, but got even_batches={even_batches}"
+            )
+
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.even_batches = even_batches
+
+        # Create a BatchSampler to generate all batches
+        self.batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+    def __iter__(self) -> Iterator[list[int]]:
+        if self.even_batches:
+            # When even_batches=True, ensure all ranks get the same number of batches
+            # by potentially dropping some batches
+            all_batches = list(self.batch_sampler)
+            total_batches = len(all_batches)
+
+            # Calculate how many batches each rank should get to make them even
+            batches_per_rank = total_batches // self.num_replicas
+
+            # Only consider the first batches_per_rank * num_replicas batches
+            # This ensures even distribution
+            total_even_batches = batches_per_rank * self.num_replicas
+
+            batch_idx = 0
+            for batch in all_batches:
+                if batch_idx >= total_even_batches:
+                    # Stop yielding once we've exhausted the even batches
+                    break
+                # Only yield batches that belong to current rank
+                if batch_idx % self.num_replicas == self.rank:
+                    yield batch
+                batch_idx += 1
+        else:
+            # Original behavior when even_batches=False
+            batch_idx = 0
+            for batch in self.batch_sampler:
+                # Only yield batches that belong to current rank
+                if batch_idx % self.num_replicas == self.rank:
+                    yield batch
+                batch_idx += 1
+
+    def __len__(self) -> int:
+        # Calculate total number of batches from BatchSampler
+        total_batches = len(self.batch_sampler)  # type: ignore[arg-type]
+
+        if self.even_batches:
+            # When even_batches=True, all ranks get exactly the same number of batches
+            return total_batches // self.num_replicas
+        else:
+            # Original behavior when even_batches=False
+            # Each rank gets approximately total_batches // num_replicas batches
+            # The remaining batches are distributed among the first few ranks
+            batches_per_rank = total_batches // self.num_replicas
+            remaining_batches = total_batches % self.num_replicas
+
+            # Current rank gets one extra batch if it's among the first 'remaining_batches' ranks
+            if self.rank < remaining_batches:
+                return batches_per_rank + 1
+            else:
+                return batches_per_rank
 
 
 # pyre-fixme[24]: expected generic parameter