## Problem
When building pretraining data mixtures from many sources (e.g., Nemotron v2 + Common Pile + FinePDFs + FineTranslations), we run into two conflicting constraints:
- Reusability: Each dataset should be tokenized independently so others can mix and match without re-tokenizing.
- Mixture granularity: `MixtureDataset` uses block-deterministic batching, so any component smaller than ~`total_tokens / batch_size` gets `int(weight * bsz) == 0` and is never sampled. With `bsz=64` and 14T total tokens, that threshold is ~220B — meaning datasets like NuminaMath (0.4B), PEPs (0.02B), or even PubMed (39B) can't be their own components.
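To make the threshold concrete, here is the rounding arithmetic implied by the numbers above (assuming, per the bullet, that each component receives `int(weight * batch_size)` slots per batch):

```python
# Arithmetic behind the granularity threshold described above.
# Assumption: MixtureDataset gives each component int(weight * batch_size)
# slots per batch, so small weights round down to zero.
batch_size = 64
total_tokens = 14e12  # 14T tokens in the full mixture

datasets = {"PubMed": 39e9, "NuminaMath": 0.4e9, "PEPs": 0.02e9}
slots = {
    name: int(tokens / total_tokens * batch_size)
    for name, tokens in datasets.items()
}
# Every one of these rounds down to 0 slots: they can never be sampled.

threshold = total_tokens / batch_size  # ≈ 218.75e9, i.e. ~220B tokens
```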
The current workaround is merging small datasets into a single tokenization step (e.g., `cp_nl` concatenates 24 Common Pile datasets into one cache). This works, but it means:
- The merged cache can't be decomposed — if someone wants just peS2o, they need to re-tokenize it separately.
- Every different grouping produces a new cache, wasting storage.
## Proposal: `FlatMixture(AsyncDataset)`

A `FlatMixture` logically concatenates multiple `AsyncDataset` children and applies a single global shuffle, behaving as if all the data were tokenized together into one dataset. It's an `AsyncDataset` itself, so it can be used as a component in a regular `MixtureDataset` (enabling nesting).

### Key properties

- Virtual concatenation: no data copying. Children keep their own caches.
- Global shuffle: a single permutation (Feistel or similar) over the full concatenated index space, equivalent to `cat * | shuf`.
- Deterministic + resumable: `global_index → permuted_index → (which_child, local_offset)` is a pure function. Child mapping is a binary search on cumulative lengths. No serialized state is needed for resumption.
- Nestable: since it's an `AsyncDataset`, a `FlatMixture` can be a child of another `FlatMixture` or a weighted component in a `MixtureDataset`.
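One way to realize the single global permutation is a keyed Feistel network with cycle-walking. This is a minimal sketch under stated assumptions, not the actual implementation: it keys on `bytes` rather than a `PRNGKeyArray`, and the round function and round count are illustrative choices.

```python
import hashlib

def _round_fn(half: int, round_key: bytes) -> int:
    # Keyed round function: hash the half-block together with the round key.
    h = hashlib.blake2b(half.to_bytes(8, "big"), key=round_key, digest_size=8)
    return int.from_bytes(h.digest(), "big")

def feistel_permute(index: int, n: int, key: bytes, rounds: int = 4) -> int:
    """Deterministic bijection over [0, n): a balanced Feistel network,
    with cycle-walking to handle non-power-of-two n."""
    bits = max(2, (n - 1).bit_length())
    half_bits = (bits + 1) // 2
    mask = (1 << half_bits) - 1
    x = index
    while True:
        left, right = x >> half_bits, x & mask
        for r in range(rounds):
            left, right = right, left ^ (_round_fn(right, key + bytes([r])) & mask)
        x = (left << half_bits) | right
        if x < n:  # cycle-walk: re-encrypt until we land back inside [0, n)
            return x
```

Because the mapping is a pure function of `(index, n, key)`, resumption needs no serialized shuffle state: re-deriving the permutation from the key is enough.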
### Interface sketch

```python
class FlatMixture(AsyncDataset[T]):
    def __init__(self, datasets: dict[str, AsyncDataset[T]], *, key: PRNGKeyArray):
        # Concatenates index spaces, applies a global shuffle
        ...

    async def async_len(self) -> int:
        return sum(child_lengths)

    async def getitem_async(self, index: Index) -> T:
        shuffled = self.permutation[index]
        child, local_idx = self._resolve(shuffled)  # binary search on cumulative lengths
        return await self.children[child].getitem_async(local_idx)
```
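The `_resolve` step can be sketched as a standalone function using `bisect` over cumulative child lengths (assuming, as the proposal does, that every child's length is known up front):

```python
import bisect
import itertools

def resolve(global_idx: int, cum_lengths: list[int]) -> tuple[int, int]:
    """Map an (already permuted) global index to (child_index, local_offset).

    cum_lengths holds inclusive cumulative ends: children of sizes
    [5, 3, 2] give [5, 8, 10].
    """
    child = bisect.bisect_right(cum_lengths, global_idx)
    start = cum_lengths[child - 1] if child > 0 else 0
    return child, global_idx - start

# Three children of sizes 5, 3, and 2:
cum_lengths = list(itertools.accumulate([5, 3, 2]))  # [5, 8, 10]
```

`bisect_right` makes the boundaries land correctly: index 4 is the last item of child 0, and index 5 is the first item of child 1.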
### Usage pattern

```python
# Each dataset tokenized independently (reusable)
peS2o_tokenized = default_tokenize("peS2o", ...)
pubmed_tokenized = default_tokenize("pubmed", ...)
arxiv_tokenized = default_tokenize("arxiv", ...)
peps_tokenized = default_tokenize("peps", ...)  # tiny, 0.02B

# Group into a FlatMixture — behaves like one big dataset
cp_nl = FlatMixture(
    {
        "peS2o": peS2o_tokenized,
        "pubmed": pubmed_tokenized,
        "arxiv": arxiv_tokenized,
        "peps": peps_tokenized,
    },
    key=key,
)

# Use as a component in a regular weighted MixtureDataset
mixture = MixtureDataset(
    datasets={"nemotron_cc": nemotron_cc, "cp_nl": cp_nl, "code": code},
    weights={"nemotron_cc": 9000, "cp_nl": 518, "code": 1019},
    ...,
)
```
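As a sanity check on those weights, assuming the `int(weight * bsz)` slotting from the Problem section and treating the weights as proportional shares:

```python
# Weights from the sketch above, treated as proportional (not normalized).
weights = {"nemotron_cc": 9000, "cp_nl": 518, "code": 1019}
bsz = 64
total = sum(weights.values())

# Grouped as cp_nl, the Common Pile datasets clear the
# int(weight * bsz) == 0 threshold that PEPs (0.02B) alone never could.
slots = {name: int(w / total * bsz) for name, w in weights.items()}
# cp_nl gets 3 of 64 slots per batch instead of rounding down to 0.
```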
### Non-goals for v1

- Per-component weighting within `FlatMixture` (the weight IS the size — natural token-proportional sampling)
- Dynamic weight changes within `FlatMixture` (use `MixtureDataset` for that)