Skip to content

Commit 8a465b6

Browse files
levanter: parallelize build_caches over components (#5388)
* parallelize `LmDataConfig.build_caches` since sequential GCS round-trips dominated startup (~40 min for ~100 components in the Datakit Testbed before the first training step) * run per-component work in a `ThreadPoolExecutor` with `max_workers=min(32, len(items))`; work is GCS-metadata-bound (ledger reads, per-shard `ShardedTreeCache.__init__`) so threads fit [^1] * refactor the loop body into a `_build_one` helper returning `(name, cache_or_None)`; pre-filter eligible components (skip zero-weight train, `DirectDatasetComponent`, raise on unsupported types) before scheduling, then post-filter `None` results when keying the result dict * wrap the executor in `rigging.timing.log_time` so total wall time per `build_caches[<split>]` lands in the logs * skip and exception semantics unchanged — one bad component still fails the whole build * add unit tests in `lib/levanter/tests/test_text.py` * `test_build_caches_returns_all_components_in_parallel` — 4-component build, asserts the result dict is keyed by name with the right cache contents * `test_build_caches_propagates_exception_from_one_component` — mixed good/bad pair must raise so errors aren't swallowed by `pool.map` [^1]: cap of 32 avoids hammering GCS on very large component lists. --------- Co-authored-by: Rafal Wojdyla <ravwojdyla@gmail.com>
1 parent a4a43fa commit 8a465b6

2 files changed

Lines changed: 85 additions & 28 deletions

File tree

lib/levanter/src/levanter/data/text/datasets.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import os
99
from collections.abc import Callable, Mapping, Sequence
10+
from concurrent.futures import ThreadPoolExecutor
1011
from dataclasses import dataclass
1112
from functools import cached_property
1213
from typing import Literal, NotRequired, TypeAlias, TypeVar, TypedDict
@@ -18,6 +19,7 @@
1819
from draccus import ChoiceRegistry, field
1920
from haliax import Axis
2021
from jaxtyping import PRNGKeyArray
22+
from rigging.timing import log_time
2123

2224
import levanter
2325
from levanter.data import AsyncDataset
@@ -861,58 +863,78 @@ def validation_grug_sets(self, *, seq_len: int) -> Mapping[str, AsyncDataset[Gru
861863
return self._validation_datasets_unwrapped(Pos)
862864

863865
def build_caches(self, split: str) -> dict[str, TreeCache[dict]]:
864-
caches: dict[str, TreeCache[dict]] = {}
866+
items: list[tuple[str, "DatasetComponent"]] = []
865867
for name, component in self.components.items():
866868
if split == "train" and not self._has_nonzero_weight(name):
867869
continue
868-
869870
if isinstance(component, DirectDatasetComponent):
870871
continue
871-
872872
if not isinstance(component, DatasetComponent):
873873
raise ValueError(f"Unsupported component type for {name}: {type(component)}")
874-
874+
items.append((name, component))
875+
876+
if not items:
877+
return {}
878+
879+
# Loads are pure GCS metadata reads and parallelize cleanly. Builds may
880+
# enter `_distributed_build_cache`, which uses unidentified jax
881+
# collectives paired across processes by dispatch order — running
882+
# multiple of those concurrently can cross-wire status broadcasts or
883+
# hang. Classify each component in the pool, then build any misses
884+
# serially in the original component order.
885+
def _load_or_defer(
886+
item: tuple[str, "DatasetComponent"],
887+
) -> tuple[str, TreeCache[dict] | None, tuple[str, ShardedDataSource, LmDatasetFormatBase] | None]:
888+
name, component = item
875889
cache_root = _component_cache_dir(name, component, self.cache_dir)
890+
cache_path = os.path.join(cache_root, split)
876891
source = component.source
877892

878893
if source is None:
879894
try:
880-
caches[name] = load_lm_dataset_cache(
881-
os.path.join(cache_root, split), component.format, self.the_tokenizer, self.enforce_eos
882-
)
895+
cache = load_lm_dataset_cache(cache_path, component.format, self.the_tokenizer, self.enforce_eos)
883896
except FileNotFoundError:
884897
raise ValueError(f"No source and no cache found for component {name} split {split}")
885-
continue
898+
return name, cache, None
886899

887900
shard_source = source.get_shard_source(split)
901+
cache_exists = fsspec_utils.exists(cache_path)
902+
888903
if shard_source is None:
889-
cache_path = os.path.join(cache_root, split)
890-
if not fsspec_utils.exists(cache_path):
904+
if not cache_exists:
891905
logger.warning(f"No source for {name} in {split} split and no cache at {cache_path}, skipping")
892-
continue
893-
caches[name] = load_lm_dataset_cache(
894-
cache_path, component.format, self.the_tokenizer, self.enforce_eos
895-
)
896-
continue
906+
return name, None, None
907+
cache = load_lm_dataset_cache(cache_path, component.format, self.the_tokenizer, self.enforce_eos)
908+
return name, cache, None
897909

898-
cache_path = os.path.join(cache_root, split)
899910
if not self.auto_build_caches:
900-
if not fsspec_utils.exists(cache_path):
911+
if not cache_exists:
901912
raise FileNotFoundError(f"Cache not found at {cache_path} and auto_build_caches is disabled")
902-
caches[name] = load_lm_dataset_cache(
903-
cache_path, component.format, self.the_tokenizer, self.enforce_eos
904-
)
905-
continue
913+
cache = load_lm_dataset_cache(cache_path, component.format, self.the_tokenizer, self.enforce_eos)
914+
return name, cache, None
915+
916+
if cache_exists:
917+
cache = load_lm_dataset_cache(cache_path, component.format, self.the_tokenizer, self.enforce_eos)
918+
return name, cache, None
919+
return name, None, (cache_path, shard_source, component.format)
906920

921+
caches: dict[str, TreeCache[dict]] = {}
922+
to_build: list[tuple[str, tuple[str, ShardedDataSource, LmDatasetFormatBase]]] = []
923+
max_workers = min(32, len(items))
924+
with (
925+
log_time(f"build_caches[{split}] over {len(items)} components"),
926+
ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="build_caches") as pool,
927+
):
928+
for name, cache, build_args in pool.map(_load_or_defer, items):
929+
if cache is not None:
930+
caches[name] = cache
931+
elif build_args is not None:
932+
to_build.append((name, build_args))
933+
934+
for name, (cache_path, shard_source, fmt) in to_build:
907935
caches[name] = build_lm_dataset_cache(
908-
cache_path,
909-
shard_source,
910-
component.format,
911-
self.the_tokenizer,
912-
self.cache_options,
913-
self.enforce_eos,
936+
cache_path, shard_source, fmt, self.the_tokenizer, self.cache_options, self.enforce_eos
914937
)
915-
916938
return caches
917939

918940
@property

lib/levanter/tests/test_text.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,3 +779,38 @@ def test_chat_dataset_build_and_pack(dummy_chat_data):
779779

780780
# loss_weight should coincide with assistant tokens only
781781
assert_loss_weight_matches_all_assistants(ex, tokenizer)
782+
783+
784+
# --- LmDataConfig.build_caches ---------------------------------------------
785+
786+
787+
def _write_prebuilt_jsonl(path: Path, records: list[dict]) -> None:
788+
with path.open("w") as f:
789+
for record in records:
790+
f.write(json.dumps(record) + "\n")
791+
792+
793+
def _prebuilt_train_component(jsonl_path: Path) -> DatasetComponent:
794+
return DatasetComponent(
795+
source=UrlDatasetSourceConfig(train_urls=[str(jsonl_path)], validation_urls=[]),
796+
format=PrebuiltLmDatasetFormat(),
797+
)
798+
799+
800+
def test_build_caches_propagates_exception_from_one_component(tmp_path):
801+
p_good = tmp_path / "good.jsonl"
802+
_write_prebuilt_jsonl(p_good, [{"input_ids": [1, 2, 3, 4]}])
803+
good = _prebuilt_train_component(p_good)
804+
bad = DatasetComponent(
805+
source=None,
806+
cache_dir=str(tmp_path / "bad_missing"),
807+
format=PrebuiltLmDatasetFormat(),
808+
)
809+
config = LmDataConfig(
810+
components={"good": good, "bad": bad},
811+
cache_dir=str(tmp_path / "caches"),
812+
tokenizer="passthrough",
813+
vocab_size=16,
814+
)
815+
with pytest.raises(ValueError, match="No source and no cache"):
816+
config.build_caches("train")

0 commit comments

Comments
 (0)