
🤖 FlatMixture: virtual dataset concatenation without re-tokenizing #4132

@Helw150

Description

Problem

When building pretraining data mixtures from many sources (e.g., Nemotron v2 + Common Pile + FinePDFs + FineTranslations), we run into two conflicting constraints:

  1. Reusability: Each dataset should be tokenized independently so others can mix and match without re-tokenizing.
  2. Mixture granularity: MixtureDataset uses block-deterministic batching, so each component is sampled int(weight * bsz) times per block. Any component smaller than ~total_tokens / batch_size rounds to zero and is never sampled. With bsz=64 and 14T total tokens, that threshold is ~220B, so datasets like NuminaMath (0.4B), PEPs (0.02B), or even PubMed (39B) can't be their own components (see the quick check below).
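
A quick check of that threshold with the numbers above (treating a component's weight as its token share of the mixture, which this sketch assumes):

total_tokens = 14e12   # 14T tokens in the full mixture
bsz = 64               # sequences drawn per block

# A component is sampled only if it gets at least one slot per block:
# int(weight * bsz) >= 1  =>  weight >= 1 / bsz  =>  size >= total_tokens / bsz
print(total_tokens / bsz / 1e9)   # ~219, i.e. the ~220B-token threshold

for name, size in [("NuminaMath", 0.4e9), ("PEPs", 0.02e9), ("PubMed", 39e9)]:
    print(name, int(size / total_tokens * bsz))   # 0 for all three: never sampled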

The current workaround is to merge small datasets into a single tokenization step (e.g., cp_nl concatenates 24 Common Pile datasets into one cache). This works, but it means:

  • The merged cache can't be decomposed: anyone who wants just peS2o has to re-tokenize it separately.
  • Every different grouping produces a new cache, wasting storage.

Proposal: FlatMixture(AsyncDataset)

A FlatMixture logically concatenates multiple AsyncDataset children and applies a single global shuffle, behaving as if all the data were tokenized together into one dataset. It's an AsyncDataset itself, so it can be used as a component in a regular MixtureDataset (enabling nesting).

Key properties

  • Virtual concatenation: no data copying. Children keep their own caches.
  • Global shuffle: a single permutation (Feistel or similar) over the full concatenated index space, equivalent to cat * | shuffle (sketched just after this list).
  • Deterministic + resumable: global_index → permuted_index → (which_child, local_offset) is a pure function. Child mapping is binary search on cumulative lengths. No serialized state needed for resumption.
  • Nestable: since it's an AsyncDataset, a FlatMixture can be a child of another FlatMixture or a weighted component in MixtureDataset.
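
The global shuffle is the one piece that can't be a materialized index array at pretraining scale (O(N) memory for trillions of rows); a keyed Feistel network gives the same bijection in O(1) memory. A minimal sketch, assuming nothing about Levanter's internals (feistel_permute is a hypothetical helper, and the round function is Python's hash purely for brevity):

def feistel_permute(i: int, n: int, key: int, rounds: int = 4) -> int:
    # Deterministic pseudo-random permutation of [0, n): a Feistel network
    # over the next even power-of-two domain, plus cycle-walking back into range.
    bits = n.bit_length() + (n.bit_length() & 1)    # even bit count, 2**bits >= n
    half, mask = bits // 2, (1 << (bits // 2)) - 1
    while True:
        l, r = i >> half, i & mask
        for rnd in range(rounds):
            # hash() is deterministic for ints; any keyed mixer would do
            l, r = r, l ^ (hash((r, key, rnd)) & mask)
        i = (l << half) | r
        if i < n:   # cycle-walk: re-encrypt until the value lands back in [0, n)
            return i

# Sanity check: the mapping really is a permutation
assert sorted(feistel_permute(i, 14, key=7) for i in range(14)) == list(range(14))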

Interface sketch

class FlatMixture(AsyncDataset[T]):
    def __init__(self, datasets: dict[str, AsyncDataset[T]], *, key: PRNGKeyArray):
        # Concatenates index spaces, applies a single global shuffle
        self.children = datasets
        ...

    async def async_len(self) -> int:
        # Total length is just the sum of the children's lengths
        return sum(await child.async_len() for child in self.children.values())

    async def getitem_async(self, index: Index) -> T:
        shuffled = self.permutation[index]          # single global permutation
        child, local_idx = self._resolve(shuffled)  # binary search on cumulative lengths
        return await self.children[child].getitem_async(local_idx)
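
The _resolve step is small enough to show runnable. A standalone sketch (make_resolver is a hypothetical helper; in the class above this logic would sit behind getitem_async):

import bisect
from itertools import accumulate

def make_resolver(child_lengths: list[int]):
    # Child i owns the flat index range [cum[i], cum[i+1]); one bisect resolves it.
    cum = [0, *accumulate(child_lengths)]

    def resolve(flat_index: int) -> tuple[int, int]:
        child = bisect.bisect_right(cum, flat_index) - 1
        return child, flat_index - cum[child]

    return resolve

resolve = make_resolver([5, 2, 9])  # three children of sizes 5, 2, 9
assert resolve(0) == (0, 0)         # first child, first item
assert resolve(6) == (1, 1)         # second child, offset 1
assert resolve(15) == (2, 8)        # last child, last item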

Usage pattern

# Each dataset tokenized independently (reusable)
peS2o_tokenized = default_tokenize("peS2o", ...)
pubmed_tokenized = default_tokenize("pubmed", ...)
arxiv_tokenized = default_tokenize("arxiv", ...)
peps_tokenized = default_tokenize("peps", ...)  # tiny, 0.02B

# Group into a FlatMixture — behaves like one big dataset
cp_nl = FlatMixture(
    {"peS2o": peS2o_tokenized, "pubmed": pubmed_tokenized,
     "arxiv": arxiv_tokenized, "peps": peps_tokenized},
    key=key,
)

# Use as a component in a regular weighted MixtureDataset
mixture = MixtureDataset(
    datasets={"nemotron_cc": nemotron_cc, "cp_nl": cp_nl, "code": code},
    weights={"nemotron_cc": 9000, "cp_nl": 518, "code": 1019},
    ...
)
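
And since a FlatMixture is itself an AsyncDataset, nesting needs no extra machinery (a hypothetical grouping, reusing the tokenized datasets above):

# A FlatMixture can be a child of another FlatMixture
science = FlatMixture({"pubmed": pubmed_tokenized, "arxiv": arxiv_tokenized}, key=k1)
cp_nl_nested = FlatMixture(
    {"science": science, "peS2o": peS2o_tokenized, "peps": peps_tokenized}, key=k2
)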

Non-goals for v1

  • Per-component weighting within FlatMixture (the weight IS the size: sampling is naturally token-proportional; see the quick arithmetic after this list)
  • Dynamic weight changes within FlatMixture (use MixtureDataset for that)
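
That first non-goal falls directly out of the construction: a child's effective sampling probability is its share of the flat index space. With the two token counts quoted above (other children omitted for illustration):

sizes = {"pubmed": 39e9, "peps": 0.02e9}  # token counts from the issue
total = sum(sizes.values())
print({k: round(v / total, 4) for k, v in sizes.items()})
# {'pubmed': 0.9995, 'peps': 0.0005}  token-proportional, no explicit weights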
