Skip to content

Commit d7c4111

Browse files
claude[bot]github-actions[bot]claudedlwhHelw150
authored
Add ConcatDataset (originally proposed as FlatMixture): virtual dataset concatenation with global shuffle (#4133)
ConcatDataset logically concatenates multiple AsyncDatasets into a single index space; wrapping it in a PermutationDataset provides deterministic global shuffling. This lets small datasets be grouped as one MixtureDataset component without re-tokenizing into a merged cache. Children keep their own caches, and ConcatDataset resolves global indices to (child, local_offset) via np.searchsorted on cumulative lengths. Fixes #4132 --------- Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: David Hall <dlwh@users.noreply.github.com> Co-authored-by: William Held <Helw150@users.noreply.github.com> Co-authored-by: William Held <will.held@openathena.ai>
1 parent 69135bf commit d7c4111

3 files changed

Lines changed: 154 additions & 3 deletions

File tree

lib/levanter/src/levanter/data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ._preprocessor import BatchProcessor
55
from .dataset import AsyncDataset, ListAsyncDataset, MappedAsyncDataset, SyncDataset
66
from .loader import DataLoader
7-
from .mixture import MixtureDataset, StopStrategy
7+
from .mixture import ConcatDataset, MixtureDataset, StopStrategy
88
from .permutation import BlockShufflingDataset, EraShufflingDataset, PermutationDataset
99
from .sharded_datasource import ShardedDataSource, datasource_from_hf, datasource_from_json, datasource_from_jsonl
1010
from .utils import batched
@@ -15,6 +15,7 @@
1515
"BatchProcessor",
1616
"BlockShufflingDataset",
1717
"DataLoader",
18+
"ConcatDataset",
1819
"EraShufflingDataset",
1920
"ListAsyncDataset",
2021
"MappedAsyncDataset",

lib/levanter/src/levanter/data/mixture.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import asyncio
55
import functools
6+
import logging
67
import warnings
78
from typing import List, Mapping, Sequence, Tuple, TypeVar
89

@@ -18,6 +19,8 @@
1819
from levanter.utils.index import Index
1920
from levanter.utils.thread_utils import blocking_wait, future_from_value
2021

22+
logger = logging.getLogger(__name__)
23+
2124

2225
T = TypeVar("T")
2326

@@ -395,6 +398,105 @@ def _dataset_of_id(self, id):
395398
return self.datasets[self.dataset_index[id]]
396399

397400

401+
class ConcatDataset(AsyncDataset[T]):
    """Virtually concatenates multiple AsyncDatasets into a single index space.

    ConcatDataset logically concatenates its children so that indices
    ``[0, len(child_0))`` map to the first child, ``[len(child_0), len(child_0) +
    len(child_1))`` to the second, and so on. No data is copied.

    All children must be finite. To shuffle the concatenated result, wrap with
    :class:`levanter.data.PermutationDataset`.

    Args:
        datasets: Named child datasets to concatenate.

    Raises:
        ValueError: if no datasets are given, if any child is not finite, or
            (on first access) if the total length is 0.
    """

    def __init__(
        self,
        datasets: Mapping[str, AsyncDataset[T]],
    ):
        if len(datasets) == 0:
            raise ValueError("ConcatDataset requires at least one dataset")

        for name, ds in datasets.items():
            if not ds.is_finite():
                raise ValueError(f"ConcatDataset requires all children to be finite, but '{name}' is not")

        self.datasets = dict(datasets)
        self._names = list(self.datasets.keys())
        self._children = list(self.datasets.values())

        # Lengths are resolved lazily on first access, since child lengths may
        # require async work (e.g. reading cache metadata).
        self._cumulative_lengths: np.ndarray = np.array([])
        self._total_length: int | None = None

    def is_finite(self) -> bool:
        return True

    async def async_len(self) -> int:
        await self._ensure_initialized()
        assert self._total_length is not None
        return self._total_length

    def _resolve(self, index: int) -> tuple[int, int]:
        """Map a global index to ``(child_id, local_offset)``.

        Must be called after :meth:`_ensure_initialized`. Raises IndexError for
        out-of-range indices (including negatives, which would otherwise wrap
        silently into a child via Python negative indexing).
        """
        assert self._total_length is not None
        if index < 0 or index >= self._total_length:
            raise IndexError(f"index {index} out of range for ConcatDataset of length {self._total_length}")
        cumulative = self._cumulative_lengths
        # side="right" places an index equal to a cumulative boundary in the
        # *next* child; this also correctly skips zero-length children.
        child_id = int(np.searchsorted(cumulative, index, side="right"))
        local_offset = index - (int(cumulative[child_id - 1]) if child_id > 0 else 0)
        return child_id, local_offset

    async def get_batch(self, indices: Sequence[int]) -> Sequence[T]:
        if not indices:
            return []

        await self._ensure_initialized()

        # Group the requested global indices by the child that owns them,
        # remembering where each result belongs in the output batch.
        child_indices: list[list[int]] = [[] for _ in self._children]
        result_positions: list[list[int]] = [[] for _ in self._children]

        for batch_pos, idx in enumerate(indices):
            child_id, local_offset = self._resolve(idx)
            child_indices[child_id].append(local_offset)
            result_positions[child_id].append(batch_pos)

        # Fetch from each child in parallel.
        batch_futures = []
        for child_id, idx_list in enumerate(child_indices):
            if len(idx_list) == 0:
                batch_futures.append(future_from_value([]))
            else:
                batch_futures.append(self._children[child_id].get_batch(idx_list))

        child_batches = await asyncio.gather(*batch_futures)

        # Reassemble into the caller's requested order.
        result: list[T | None] = [None] * len(indices)
        for child_id, positions in enumerate(result_positions):
            for i, pos in enumerate(positions):
                result[pos] = child_batches[child_id][i]

        return result  # type: ignore

    async def getitem_async(self, index: int) -> T:
        await self._ensure_initialized()
        child_id, local_offset = self._resolve(index)
        return await self._children[child_id].getitem_async(local_offset)

    async def _ensure_initialized(self):
        # Idempotent: cheap no-op once lengths have been computed.
        if self._total_length is not None:
            return

        lengths = await asyncio.gather(*[child.async_len() for child in self._children])
        cumulative = np.cumsum(lengths)
        self._cumulative_lengths = cumulative
        self._total_length = int(cumulative[-1])

        if self._total_length == 0:
            raise ValueError("ConcatDataset total length is 0 — all children are empty")

    def __repr__(self):
        return f"ConcatDataset({self._names})"
498+
499+
398500
def _compute_block_assignment(base_ids, index, key):
399501
rng = jax.random.fold_in(key, index)
400502
permuted_ids = jax.random.permutation(rng, base_ids)

lib/levanter/tests/test_mixture.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import numpy as np
88
import pytest
99

10-
from levanter.data import ListAsyncDataset, MixtureDataset
10+
from levanter.data import ListAsyncDataset, MixtureDataset, PermutationDataset
1111
from levanter.data.dataset import AsyncDataset
12-
from levanter.data.mixture import StopStrategy, rescale_mixture_schedule_for_batch_schedule
12+
from levanter.data.mixture import ConcatDataset, StopStrategy, rescale_mixture_schedule_for_batch_schedule
1313
from levanter.schedule import BatchSchedule, ScheduleStep
1414

1515

@@ -294,3 +294,51 @@ def test_rescale_mixture_schedule_for_batch_schedule():
294294
expected_schedule = [(0, {"ds1": 0.5, "ds2": 0.5}), (100, {"ds1": 0.2, "ds2": 0.8})]
295295

296296
assert rescaled_schedule == expected_schedule
297+
298+
299+
# --- ConcatDataset tests ---
300+
301+
302+
@pytest.mark.asyncio
async def test_concat_dataset_getitem_consistent_with_get_batch():
    """getitem_async must agree element-wise with get_batch across both children."""
    concat = ConcatDataset({"a": ListAsyncDataset([1, 2, 3]), "b": ListAsyncDataset([10, 20])})
    whole = await concat.get_batch(list(range(5)))
    singles = [await concat.getitem_async(i) for i in range(5)]
    assert singles == list(whole)
310+
311+
312+
@pytest.mark.asyncio
async def test_concat_with_permutation_is_a_permutation():
    """Every index maps to a unique element — no duplicates, no missing."""
    first = ListAsyncDataset(list(range(50)))
    second = ListAsyncDataset(list(range(50, 100)))
    shuffled = PermutationDataset(ConcatDataset({"a": first, "b": second}), key=key())

    n = await shuffled.async_len()
    items = await shuffled.get_batch(list(range(n)))
    assert sorted(items) == list(range(100))
322+
323+
324+
@pytest.mark.asyncio
async def test_concat_with_permutation_nests_in_mixture_dataset():
    """ConcatDataset + PermutationDataset can be used as a MixtureDataset component."""
    flat = PermutationDataset(
        ConcatDataset(
            {
                "a": ListAsyncDataset(list(range(10))),
                "b": ListAsyncDataset(list(range(10, 20))),
            }
        ),
        key=key(),
    )
    other = ListAsyncDataset(list(range(100, 110)))

    mixture = MixtureDataset(
        {"flat": flat, "other": other},
        {"flat": 0.5, "other": 0.5},
        block_size=10,
        key=key(),
        randomize_blocks=False,
    )

    batch = await mixture.get_batch(list(range(20)))
    assert len(batch) == 20
    allowed = set(range(20)) | set(range(100, 110))
    assert all(item in allowed for item in batch)

0 commit comments

Comments
 (0)