Skip to content

Commit 73840cc

Browse files
eric-czech and claude
committed
levanter: add randomize_epochs flag to MixtureDataset
When set, each pass through a finite mixture component uses an independent permutation of its samples instead of the natural-order cycle (`raw_idx % L`). The flag defaults to False so existing behavior is preserved. The permutation is built by a new module-level helper `_compute_epoch_assignment` that derives a per-(dataset_id, epoch) Feistel permutation via `fold_in(fold_in(key, dataset_id), epoch)`, cached per instance. `_remap_indices` now takes the dataset id and dispatches to the per-epoch permutation when the flag is on; under FIRST_STOP_STRATEGY the flag is a no-op since no component completes more than one pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e0d23b3 commit 73840cc

2 files changed

Lines changed: 72 additions & 9 deletions

File tree

lib/levanter/src/levanter/data/mixture.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from levanter.utils.jax_utils import local_cpu_mesh
1616

1717
from levanter.data import AsyncDataset
18+
from levanter.data._prp import Permutation
1819
from levanter.schedule import BatchSchedule
1920
from levanter.utils.index import Index
2021
from levanter.utils.thread_utils import blocking_wait, future_from_value
@@ -47,6 +48,11 @@ class MixtureDataset(AsyncDataset[T]):
4748
- FIRST_STOP_STRATEGY: stop when one dataset has been exhausted
4849
- ALL_STOP_STRATEGY: stop when all datasets have been exhausted
4950
- RESTART_STRATEGY: restart the dataset when it has been exhausted
51+
randomize_epochs: if True, each pass through a finite mixture component uses an
52+
independent permutation of its samples; if False, the component is accessed in
53+
natural order via ``raw_idx % length``. Takes effect only for finite components
54+
under ``RESTART_STRATEGY`` or ``ALL_STOP_STRATEGY``; under ``FIRST_STOP_STRATEGY``
55+
no component completes more than one pass, so there is no second epoch to permute.
5056
key: random key for datasets sampling
5157
"""
5258

@@ -57,6 +63,7 @@ def __init__(
5763
block_size: int,
5864
*,
5965
randomize_blocks: bool = True,
66+
randomize_epochs: bool = False,
6067
key: PRNGKeyArray | int,
6168
stop_strategy: str = StopStrategy.RESTART_STRATEGY,
6269
):
@@ -94,6 +101,7 @@ def __init__(
94101
raise ValueError(f"Block size must be at most 2^16, got {block_size}")
95102

96103
self.randomize_blocks = randomize_blocks
104+
self.randomize_epochs = randomize_epochs
97105

98106
# this stupid dance is to ensure that the key is on CPU so we don't end up with weird device placement issues
99107
# in recent JAX.
@@ -255,7 +263,7 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T]:
255263
batch_futures.append(future_from_value([]))
256264
else:
257265
dataset = self._dataset_of_id(dataset_id)
258-
indices_for_dataset = await self._remap_indices(dataset, indices_for_dataset)
266+
indices_for_dataset = await self._remap_indices(dataset, indices_for_dataset, dataset_id)
259267
batch_futures.append(dataset.get_batch(indices_for_dataset))
260268

261269
batches = await asyncio.gather(*batch_futures)
@@ -279,14 +287,12 @@ async def getitem_async(self, index: int) -> T:
279287
dataset_id, dataset_index = self._index_into_dataset_for_id(permuted_ids[index], block_id)
280288

281289
dataset = self._dataset_of_id(dataset_id)
282-
dataset_index = (await self._remap_indices(dataset, [dataset_index]))[0]
290+
dataset_index = (await self._remap_indices(dataset, [dataset_index], dataset_id))[0]
283291

284292
return await dataset.getitem_async(dataset_index)
285293

286-
async def _remap_indices(self, ds, indices_into_ds):
287-
"""
288-
Handles wrap around for datasets that have finite length
289-
"""
294+
async def _remap_indices(self, ds, indices_into_ds, dataset_id: int):
295+
"""Handles wrap around for datasets that have finite length."""
290296
if self.stop_strategy in [StopStrategy.RESTART_STRATEGY, StopStrategy.ALL_STOP_STRATEGY]:
291297
if ds.is_finite():
292298
length_of_dataset = await ds.async_len()
@@ -295,7 +301,10 @@ async def _remap_indices(self, ds, indices_into_ds):
295301
"MixtureDataset in RESTART_STRATEGY encountered an empty finite dataset "
296302
"(`async_len()` returned 0). Restart strategy does not support empty datasets."
297303
)
298-
indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds]
304+
if self.randomize_epochs:
305+
indices_into_ds = self._apply_epoch_permutation(dataset_id, length_of_dataset, indices_into_ds)
306+
else:
307+
indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds]
299308

300309
return indices_into_ds
301310

@@ -304,6 +313,22 @@ async def _remap_indices(self, ds, indices_into_ds):
304313

305314
raise ValueError(f"Unknown stop strategy: {self.stop_strategy}")
306315

316+
def _apply_epoch_permutation(self, dataset_id: int, length: int, indices_into_ds: Sequence[int]) -> list[int]:
    """Remap raw (possibly multi-epoch) indices through per-epoch permutations.

    Each raw index decomposes as ``epoch * length + offset``; the offset is
    sent through the permutation belonging to that (dataset, epoch) pair.
    """
    flat = np.asarray(indices_into_ds, dtype=np.int64)
    epoch_ids = flat // length
    offsets = flat % length
    remapped = np.empty_like(offsets)
    # A batch may straddle an epoch boundary; each epoch uses its own permutation,
    # so process the indices grouped by the epoch they fall into.
    for epoch_id in np.unique(epoch_ids).tolist():
        selector = epoch_ids == epoch_id
        permutation = self._get_epoch_permutation(dataset_id, int(epoch_id), length)
        remapped[selector] = permutation(offsets[selector])
    return [int(v) for v in remapped]
327+
328+
def _get_epoch_permutation(self, dataset_id: int, epoch: int, length: int) -> Permutation:
    """Return the cached permutation for one (dataset_id, epoch) pair.

    Uses a lazily-created, per-instance dict instead of ``functools.lru_cache``
    on the method: an lru_cache'd method keys on ``self`` and stores entries in
    a class-level cache, which keeps every MixtureDataset instance alive for
    the lifetime of the process (ruff B019) — contrary to the intended
    per-instance caching. The dict is bounded (FIFO eviction) so an
    ever-growing epoch count cannot leak memory.
    """
    cache = getattr(self, "_epoch_perm_cache", None)
    if cache is None:
        cache = {}
        self._epoch_perm_cache = cache
    cache_key = (dataset_id, epoch, length)
    perm = cache.get(cache_key)
    if perm is None:
        # Bound the cache like the previous lru_cache(maxsize=128) did;
        # evict the oldest entry (insertion order) when full.
        if len(cache) >= 128:
            cache.pop(next(iter(cache)))
        perm = _compute_epoch_assignment(dataset_id, epoch, length, self.key)
        cache[cache_key] = perm
    return perm
331+
307332
def _set_finiteness_cache(self, finite_length: int | None) -> int | None:
308333
self._cached_finite_length = finite_length
309334
self._is_finite_cache = finite_length is not None
@@ -503,6 +528,14 @@ def _compute_block_assignment(base_ids, index, key):
503528
return permuted_ids
504529

505530

531+
def _compute_epoch_assignment(dataset_id: int, epoch: int, length: int, key: PRNGKeyArray) -> Permutation:
    """Derive a deterministic Feistel permutation of ``[0, length)`` for one
    (dataset_id, epoch) pair by folding both ids into the mixture's base key.
    """
    with local_cpu_mesh():
        # fold_in twice so distinct (dataset, epoch) pairs get independent keys
        epoch_key = jax.random.fold_in(jax.random.fold_in(key, dataset_id), epoch)
        # round-trip through host memory — presumably to pin the key to a CPU
        # device and avoid accelerator placement issues; TODO confirm
        epoch_key = jax.device_put(jax.device_get(epoch_key))
        return Permutation.make("feistel", length, epoch_key)
537+
538+
506539
def rescale_mixture_schedule_for_batch_schedule(
507540
mixture_schedule: Sequence[Tuple[int, dict[str, float]]], batch_schedule: BatchSchedule
508541
) -> List[Tuple[int, dict[str, float]]]:

lib/levanter/tests/test_mixture.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,13 @@ async def test_mixture_dataset_remap_indices():
203203
dses = datasets()
204204
mixture_ds = MixtureDataset(dses, weights(), block_size=10, key=key())
205205

206-
remapped_indices = await mixture_ds._remap_indices(dses["ds1"], [0, 1, 2])
206+
remapped_indices = await mixture_ds._remap_indices(dses["ds1"], [0, 1, 2], 0)
207207
assert len(remapped_indices) == 3
208208
assert remapped_indices == [0, 1, 2]
209209

210210
# check wrap around
211211
len_ds1 = await dses["ds1"].async_len()
212-
remapped_indices = await mixture_ds._remap_indices(dses["ds1"], [len_ds1 - 1, len_ds1, len_ds1 + 1])
212+
remapped_indices = await mixture_ds._remap_indices(dses["ds1"], [len_ds1 - 1, len_ds1, len_ds1 + 1], 0)
213213
assert len(remapped_indices) == 3
214214

215215
assert remapped_indices == [len_ds1 - 1, 0, 1]
@@ -266,6 +266,36 @@ async def test_mixture_dataset_randomizes_blocks():
266266
assert not np.all(block_assignment_1 == block_assignment_3), "Block assignments should be randomized"
267267

268268

269+
@pytest.mark.asyncio
async def test_mixture_dataset_randomize_epochs_permutes_finite_component():
    """Each pass through a finite component is its own permutation when randomize_epochs=True."""
    finite_length = 8
    finite = ListAsyncDataset(list(range(finite_length)))  # value at index k is k
    dses = {"finite": finite, "infinite": InfiniteCounterDataset()}
    bs = 2 * finite_length
    weights = {"finite": 0.5, "infinite": 0.5}

    # With randomize_blocks=False and the finite dataset registered first,
    # positions [0, L) of every block come from the finite component.
    async def finite_orderings(ds: MixtureDataset, num_epochs: int) -> list[list[int]]:
        orderings = []
        for e in range(num_epochs):
            start = e * bs
            batch = await ds.get_batch(list(range(start, start + finite_length)))
            orderings.append(list(batch))
        return orderings

    shuffled = MixtureDataset(dses, weights, block_size=bs, key=key(), randomize_blocks=False, randomize_epochs=True)
    epochs = await finite_orderings(shuffled, num_epochs=3)

    expected = list(range(finite_length))
    # Every epoch must still visit each sample exactly once.
    for i, ordering in enumerate(epochs):
        assert sorted(ordering) == expected, f"Epoch {i} is not a permutation of [0, L): {ordering}"

    # The per-epoch orderings should not all coincide, and epoch 0 itself
    # should already be shuffled away from natural order.
    distinct = {tuple(o) for o in epochs}
    assert len(distinct) >= 2, f"Expected per-epoch orderings to differ, got {epochs}"
    assert epochs[0] != expected, f"Epoch 0 should not be the identity order, got {epochs[0]}"

    # Control: with the flag off, every epoch is the natural order.
    baseline = MixtureDataset(dses, weights, block_size=bs, key=key(), randomize_blocks=False, randomize_epochs=False)
    baseline_epochs = await finite_orderings(baseline, num_epochs=3)
    for i, ordering in enumerate(baseline_epochs):
        assert ordering == expected, f"Default (randomize_epochs=False) epoch {i}: {ordering}"
297+
298+
269299
@pytest.mark.asyncio
270300
async def test_mixture_dataset_samples_all_elements():
271301
mixture_ds = MixtureDataset(datasets(), weights(), block_size=10, key=key())

0 commit comments

Comments
 (0)