Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions lib/levanter/src/levanter/store/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union

Expand Down Expand Up @@ -43,6 +44,7 @@
logger = pylogging.getLogger(__name__)

LEDGER_FILE_NAME = "shard_ledger.json"
CONSOLIDATE_DATA_SIZE_WORKERS = 32

DEFAULT_LOG_LEVEL = pylogging.INFO
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
Expand Down Expand Up @@ -430,7 +432,16 @@ def consolidate_shard_caches(

shard_ledgers = [CacheLedger.load(p, metadata) for p in shard_cache_paths]

for shard_path, ledger in zip(shard_cache_paths, shard_ledgers):
# Parallel: open each TreeStore to read data_size (dominates wall time on remote storage)
def _get_data_sizes(shard_path):
store = TreeStore.open(exemplar, shard_path, mode="r", cache_metadata=True)
return jax.tree.map(lambda x: x.data_size, store.tree)

with ThreadPoolExecutor(max_workers=CONSOLIDATE_DATA_SIZE_WORKERS) as executor:
per_shard_sizes = list(executor.map(_get_data_sizes, shard_cache_paths))

# Serial: accumulate row_offset and data_offset_tree (order-dependent)
for shard_path, ledger, this_offsets in zip(shard_cache_paths, shard_ledgers, per_shard_sizes):
shard_name = os.path.basename(shard_path)
shard_info.append(
{
Expand All @@ -442,9 +453,6 @@ def consolidate_shard_caches(
}
)
total_rows += ledger.total_num_rows

this_cache = TreeStore.open(exemplar, shard_path, mode="r", cache_metadata=True)
this_offsets = jax.tree.map(lambda x: x.data_size, this_cache.tree)
data_offset_tree = jax.tree.map(operator.add, data_offset_tree, this_offsets)

TreeStore.open(exemplar, output_path, mode="w", cache_metadata=True)
Expand Down
47 changes: 46 additions & 1 deletion tests/test_consolidate_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
import jax
import numpy as np

from levanter.store.cache import CacheLedger, _consolidate_metadata, _expose_cache_rows, _extend_cache_with_other_cache
from levanter.store.cache import (
CacheLedger,
_consolidate_metadata,
_expose_cache_rows,
_extend_cache_with_other_cache,
consolidate_shard_caches,
)
from levanter.store.tree_store import TreeStore

NUM_SHARDS = 8
Expand Down Expand Up @@ -107,3 +113,42 @@ def make_row(shard_index):
merged = _build_and_consolidate(EXEMPLAR_SHAPED, make_row)
row = merged[0]
assert row["spans"].shape == (3, 2)


def _build_shard_cache(shard_path: str, exemplar, rows: list[dict]) -> None:
    """Write *rows* into a fresh shard cache at *shard_path* and commit a finished ledger."""
    shard_name = os.path.basename(shard_path)
    num_rows = len(rows)

    # Materialize the row data, then expose the rows so readers can see them.
    cache = TreeStore.open(exemplar, shard_path, mode="w", cache_metadata=True)
    cache.extend(rows)
    _expose_cache_rows(shard_path, exemplar, num_rows)

    # Commit a completed ledger so consolidation treats this shard as finished.
    CacheLedger(
        total_num_rows=num_rows,
        shard_rows={shard_name: num_rows},
        is_finished=True,
        finished_shards=[shard_name],
        field_counts={},
    )._serialize_and_commit(shard_path)


def test_consolidate_shard_caches_end_to_end():
    """Call consolidate_shard_caches directly, exercising the threaded pre-pass and Zephyr data copy."""
    with tempfile.TemporaryDirectory(prefix="levanter-test-consolidate-e2e-") as tmpdir:
        # Build NUM_SHARDS source caches; each shard's rows are filled with the
        # shard index so any reordering after the merge is detectable.
        shard_paths = []
        for shard_idx in range(NUM_SHARDS):
            path = os.path.join(tmpdir, f"shard_{shard_idx}")
            rows = [
                {"input_ids": np.full((ROW_WIDTH,), shard_idx, dtype=np.int32)}
                for _ in range(ROWS_PER_SHARD)
            ]
            _build_shard_cache(path, EXEMPLAR_FLAT, rows)
            shard_paths.append(path)

        dest = os.path.join(tmpdir, "merged")
        ledger = consolidate_shard_caches(shard_paths, dest, EXEMPLAR_FLAT)

        expected_rows = NUM_SHARDS * ROWS_PER_SHARD
        assert ledger.is_finished
        assert ledger.total_num_rows == expected_rows

        merged = TreeStore.open(EXEMPLAR_FLAT, dest, mode="r", cache_metadata=True)
        assert len(merged) == expected_rows

        # The first row of each shard's slice must carry that shard's fill value.
        for shard_idx in range(NUM_SHARDS):
            first_row = merged[shard_idx * ROWS_PER_SHARD]
            assert first_row["input_ids"][0] == shard_idx, f"shard {shard_idx} data mismatch"
Loading