Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/levanter/src/levanter/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ._preprocessor import BatchProcessor
from .dataset import AsyncDataset, ListAsyncDataset, MappedAsyncDataset, SyncDataset
from .loader import DataLoader
from .mixture import MixtureDataset, StopStrategy
from .mixture import ConcatDataset, MixtureDataset, StopStrategy
from .permutation import BlockShufflingDataset, EraShufflingDataset, PermutationDataset
from .sharded_datasource import ShardedDataSource, datasource_from_hf, datasource_from_json, datasource_from_jsonl
from .utils import batched
Expand All @@ -15,6 +15,7 @@
"BatchProcessor",
"BlockShufflingDataset",
"DataLoader",
"ConcatDataset",
"EraShufflingDataset",
"ListAsyncDataset",
"MappedAsyncDataset",
Expand Down
102 changes: 102 additions & 0 deletions lib/levanter/src/levanter/data/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import functools
import logging
import warnings
from typing import List, Mapping, Sequence, Tuple, TypeVar

Expand All @@ -18,6 +19,8 @@
from levanter.utils.index import Index
from levanter.utils.thread_utils import blocking_wait, future_from_value

logger = logging.getLogger(__name__)


T = TypeVar("T")

Expand Down Expand Up @@ -395,6 +398,105 @@ def _dataset_of_id(self, id):
return self.datasets[self.dataset_index[id]]


class ConcatDataset(AsyncDataset[T]):
    """Virtually concatenates multiple AsyncDatasets into a single index space.

    ConcatDataset logically concatenates its children so that indices
    ``[0, len(child_0))`` map to the first child, ``[len(child_0), len(child_0) +
    len(child_1))`` to the second, and so on. No data is copied.

    All children must be finite. To shuffle the concatenated result, wrap with
    :class:`levanter.data.PermutationDataset`.

    Args:
        datasets: Named child datasets to concatenate.

    Raises:
        ValueError: if ``datasets`` is empty or any child is not finite.
    """

    def __init__(
        self,
        datasets: Mapping[str, AsyncDataset[T]],
    ):
        if len(datasets) == 0:
            raise ValueError("ConcatDataset requires at least one dataset")

        for name, ds in datasets.items():
            if not ds.is_finite():
                raise ValueError(f"ConcatDataset requires all children to be finite, but '{name}' is not")

        self.datasets = dict(datasets)
        self._names = list(self.datasets.keys())
        self._children = list(self.datasets.values())

        # Lazily computed on first use (child lengths may require async work):
        # _cumulative_lengths[i] is the *exclusive* end offset of child i.
        self._cumulative_lengths: np.ndarray = np.array([])
        self._total_length: int | None = None

    def is_finite(self) -> bool:
        # Enforced in __init__: all children are finite, so the concat is too.
        return True

    async def async_len(self) -> int:
        await self._ensure_initialized()
        assert self._total_length is not None
        return self._total_length

    async def get_batch(self, indices: Sequence[int]) -> Sequence[T]:
        """Fetch the elements at global ``indices``, preserving their order.

        Indices are grouped per child so each child receives a single batched
        request; child fetches run concurrently via ``asyncio.gather``.

        Raises:
            IndexError: if any index is negative or >= the total length.
        """
        if not indices:
            return []

        await self._ensure_initialized()

        cumulative = self._cumulative_lengths
        total = self._total_length
        assert total is not None

        idx_arr = np.asarray(list(indices), dtype=np.int64)
        # Reject out-of-range (incl. negative) indices explicitly: a negative
        # index would otherwise silently fetch from the end of child 0.
        if int(idx_arr.min()) < 0 or int(idx_arr.max()) >= total:
            raise IndexError(f"index out of range for ConcatDataset of length {total}")

        # Vectorized bucketing: one searchsorted call for the whole batch.
        child_ids = np.searchsorted(cumulative, idx_arr, side="right")
        starts = np.concatenate(([0], cumulative[:-1]))
        local_offsets = idx_arr - starts[child_ids]

        # Group by child dataset, remembering where each result belongs
        # in the output batch.
        child_indices: list[list[int]] = [[] for _ in self._children]
        result_positions: list[list[int]] = [[] for _ in self._children]
        for batch_pos, (child_id, local_offset) in enumerate(zip(child_ids.tolist(), local_offsets.tolist())):
            child_indices[child_id].append(local_offset)
            result_positions[child_id].append(batch_pos)

        # Fetch from each child in parallel; empty groups resolve immediately.
        batch_futures = []
        for child_id, idx_list in enumerate(child_indices):
            if len(idx_list) == 0:
                batch_futures.append(future_from_value([]))
            else:
                batch_futures.append(self._children[child_id].get_batch(idx_list))

        child_batches = await asyncio.gather(*batch_futures)

        # Scatter child results back into their original batch positions.
        result: list[T | None] = [None] * len(indices)
        for child_id, positions in enumerate(result_positions):
            for i, pos in enumerate(positions):
                result[pos] = child_batches[child_id][i]

        return result  # type: ignore

    async def getitem_async(self, index: int) -> T:
        """Fetch the single element at global ``index``.

        Raises:
            IndexError: if ``index`` is negative or >= the total length.
        """
        await self._ensure_initialized()

        total = self._total_length
        assert total is not None
        if index < 0 or index >= total:
            raise IndexError(f"index {index} out of range for ConcatDataset of length {total}")

        cumulative = self._cumulative_lengths
        child_id = int(np.searchsorted(cumulative, index, side="right"))
        local_offset = index - (int(cumulative[child_id - 1]) if child_id > 0 else 0)
        return await self._children[child_id].getitem_async(local_offset)

    async def _ensure_initialized(self):
        # Idempotent: concurrent callers may both compute, but they compute
        # identical values, so there is no harmful race.
        if self._total_length is not None:
            return

        lengths = await asyncio.gather(*[child.async_len() for child in self._children])
        cumulative = np.cumsum(lengths)
        total = int(cumulative[-1])

        # Validate *before* caching: otherwise the first call raises but every
        # later call sees _total_length == 0 and silently succeeds.
        if total == 0:
            raise ValueError("ConcatDataset total length is 0 — all children are empty")

        self._cumulative_lengths = cumulative
        self._total_length = total

    def __repr__(self):
        return f"ConcatDataset({self._names})"


def _compute_block_assignment(base_ids, index, key):
rng = jax.random.fold_in(key, index)
permuted_ids = jax.random.permutation(rng, base_ids)
Expand Down
52 changes: 50 additions & 2 deletions lib/levanter/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import numpy as np
import pytest

from levanter.data import ListAsyncDataset, MixtureDataset
from levanter.data import ListAsyncDataset, MixtureDataset, PermutationDataset
from levanter.data.dataset import AsyncDataset
from levanter.data.mixture import StopStrategy, rescale_mixture_schedule_for_batch_schedule
from levanter.data.mixture import ConcatDataset, StopStrategy, rescale_mixture_schedule_for_batch_schedule
from levanter.schedule import BatchSchedule, ScheduleStep


Expand Down Expand Up @@ -294,3 +294,51 @@ def test_rescale_mixture_schedule_for_batch_schedule():
expected_schedule = [(0, {"ds1": 0.5, "ds2": 0.5}), (100, {"ds1": 0.2, "ds2": 0.8})]

assert rescaled_schedule == expected_schedule


# --- ConcatDataset tests ---


@pytest.mark.asyncio
async def test_concat_dataset_getitem_consistent_with_get_batch():
    """getitem_async must agree element-wise with get_batch across both children."""
    first = ListAsyncDataset([1, 2, 3])
    second = ListAsyncDataset([10, 20])
    concat = ConcatDataset({"a": first, "b": second})

    all_indices = list(range(5))
    batch = await concat.get_batch(all_indices)
    singles = [await concat.getitem_async(i) for i in all_indices]
    assert singles == list(batch)


@pytest.mark.asyncio
async def test_concat_with_permutation_is_a_permutation():
    """Every index maps to a unique element — no duplicates, no missing."""
    left = ListAsyncDataset(list(range(50)))
    right = ListAsyncDataset(list(range(50, 100)))
    shuffled = PermutationDataset(ConcatDataset({"a": left, "b": right}), key=key())

    n = await shuffled.async_len()
    elements = await shuffled.get_batch(list(range(n)))
    assert sorted(elements) == list(range(100))


@pytest.mark.asyncio
async def test_concat_with_permutation_nests_in_mixture_dataset():
    """ConcatDataset + PermutationDataset can be used as a MixtureDataset component."""
    flat = PermutationDataset(
        ConcatDataset(
            {
                "a": ListAsyncDataset(list(range(10))),
                "b": ListAsyncDataset(list(range(10, 20))),
            }
        ),
        key=key(),
    )
    other = ListAsyncDataset(list(range(100, 110)))

    mixture = MixtureDataset(
        {"flat": flat, "other": other},
        {"flat": 0.5, "other": 0.5},
        block_size=10,
        key=key(),
        randomize_blocks=False,
    )

    batch = await mixture.get_batch(list(range(20)))
    assert len(batch) == 20
    allowed = set(range(20)) | set(range(100, 110))
    assert set(batch) <= allowed
Loading