1515from levanter .utils .jax_utils import local_cpu_mesh
1616
1717from levanter .data import AsyncDataset
18+ from levanter .data ._prp import Permutation
1819from levanter .schedule import BatchSchedule
1920from levanter .utils .index import Index
2021from levanter .utils .thread_utils import blocking_wait , future_from_value
@@ -47,6 +48,11 @@ class MixtureDataset(AsyncDataset[T]):
4748 - FIRST_STOP_STRATEGY: stop when one dataset has been exhausted
4849 - ALL_STOP_STRATEGY: stop when all datasets have been exhausted
4950 - RESTART_STRATEGY: restart the dataset when it has been exhausted
51+ randomize_epochs: if True, each pass through a finite mixture component uses an
52+ independent permutation of its samples; if False, the component is accessed in
53+ natural order via ``raw_idx % length``. Takes effect only for finite components
54+ under ``RESTART_STRATEGY`` or ``ALL_STOP_STRATEGY``; under ``FIRST_STOP_STRATEGY``
55+ no component completes more than one pass, so there is no second epoch to permute.
5056 key: random key for datasets sampling
5157 """
5258
@@ -57,6 +63,7 @@ def __init__(
5763 block_size : int ,
5864 * ,
5965 randomize_blocks : bool = True ,
66+ randomize_epochs : bool = False ,
6067 key : PRNGKeyArray | int ,
6168 stop_strategy : str = StopStrategy .RESTART_STRATEGY ,
6269 ):
@@ -94,6 +101,7 @@ def __init__(
94101 raise ValueError (f"Block size must be at most 2^16, got { block_size } " )
95102
96103 self .randomize_blocks = randomize_blocks
104+ self .randomize_epochs = randomize_epochs
97105
98106 # this stupid dance is to ensure that the key is on CPU so we don't end up with weird device placement issues
99107 # in recent JAX.
@@ -255,7 +263,7 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T]:
255263 batch_futures .append (future_from_value ([]))
256264 else :
257265 dataset = self ._dataset_of_id (dataset_id )
258- indices_for_dataset = await self ._remap_indices (dataset , indices_for_dataset )
266+ indices_for_dataset = await self ._remap_indices (dataset , indices_for_dataset , dataset_id )
259267 batch_futures .append (dataset .get_batch (indices_for_dataset ))
260268
261269 batches = await asyncio .gather (* batch_futures )
@@ -279,14 +287,12 @@ async def getitem_async(self, index: int) -> T:
279287 dataset_id , dataset_index = self ._index_into_dataset_for_id (permuted_ids [index ], block_id )
280288
281289 dataset = self ._dataset_of_id (dataset_id )
282- dataset_index = (await self ._remap_indices (dataset , [dataset_index ]))[0 ]
290+ dataset_index = (await self ._remap_indices (dataset , [dataset_index ], dataset_id ))[0 ]
283291
284292 return await dataset .getitem_async (dataset_index )
285293
286- async def _remap_indices (self , ds , indices_into_ds ):
287- """
288- Handles wrap around for datasets that have finite length
289- """
294+ async def _remap_indices (self , ds , indices_into_ds , dataset_id : int ):
295+ """Handles wrap around for datasets that have finite length."""
290296 if self .stop_strategy in [StopStrategy .RESTART_STRATEGY , StopStrategy .ALL_STOP_STRATEGY ]:
291297 if ds .is_finite ():
292298 length_of_dataset = await ds .async_len ()
@@ -295,7 +301,10 @@ async def _remap_indices(self, ds, indices_into_ds):
295301 "MixtureDataset in RESTART_STRATEGY encountered an empty finite dataset "
296302 "(`async_len()` returned 0). Restart strategy does not support empty datasets."
297303 )
298- indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds ]
304+ if self .randomize_epochs :
305+ indices_into_ds = self ._apply_epoch_permutation (dataset_id , length_of_dataset , indices_into_ds )
306+ else :
307+ indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds ]
299308
300309 return indices_into_ds
301310
@@ -304,6 +313,22 @@ async def _remap_indices(self, ds, indices_into_ds):
304313
305314 raise ValueError (f"Unknown stop strategy: { self .stop_strategy } " )
306315
316+ def _apply_epoch_permutation (self , dataset_id : int , length : int , indices_into_ds : Sequence [int ]) -> list [int ]:
317+ raw = np .asarray (indices_into_ds , dtype = np .int64 )
318+ epochs = raw // length
319+ in_epoch = raw % length
320+ out = np .empty_like (in_epoch )
321+ # A batch may straddle an epoch boundary; each epoch uses its own permutation.
322+ for epoch in np .unique (epochs ).tolist ():
323+ mask = epochs == epoch
324+ perm = self ._get_epoch_permutation (dataset_id , int (epoch ), length )
325+ out [mask ] = perm (in_epoch [mask ])
326+ return [int (x ) for x in out ]
327+
328+ @functools .lru_cache (maxsize = 128 )
329+ def _get_epoch_permutation (self , dataset_id : int , epoch : int , length : int ) -> Permutation :
330+ return _compute_epoch_assignment (dataset_id , epoch , length , self .key )
331+
307332 def _set_finiteness_cache (self , finite_length : int | None ) -> int | None :
308333 self ._cached_finite_length = finite_length
309334 self ._is_finite_cache = finite_length is not None
@@ -503,6 +528,14 @@ def _compute_block_assignment(base_ids, index, key):
503528 return permuted_ids
504529
505530
def _compute_epoch_assignment(dataset_id: int, epoch: int, length: int, key: PRNGKeyArray) -> Permutation:
    """Derive a deterministic Feistel permutation of ``length`` samples for one (dataset, epoch) pair.

    The base key is folded with the dataset id and then the epoch so every
    component/epoch combination gets an independent permutation.
    """
    with local_cpu_mesh():
        derived = jax.random.fold_in(jax.random.fold_in(key, dataset_id), epoch)
        # Round-trip through host memory — presumably to pin the key to CPU and
        # avoid device-placement issues (matches the key handling elsewhere in
        # this file); confirm if the mesh context alone would suffice.
        derived = jax.device_put(jax.device_get(derived))
        return Permutation.make("feistel", length, derived)
537+
538+
506539def rescale_mixture_schedule_for_batch_schedule (
507540 mixture_schedule : Sequence [Tuple [int , dict [str , float ]]], batch_schedule : BatchSchedule
508541) -> List [Tuple [int , dict [str , float ]]]:
0 commit comments