Skip to content

Commit 7e8b8a8

Browse files
ahmeda14960 and claude authored
[levanter] Fix single-shard cache consolidation on R2/S3 (#4436)
## Summary - Skip TensorStore reopen for single-shard cache consolidation — copy shard contents directly via fsspec instead - Fixes the `Malformed StorageGeneration` error that breaks CoreWeave CI tokenization on R2-backed S3 - Adds regression test for single-shard consolidation path Fixes #4433 ## Test plan - [x] `uv run pytest -q tests/test_consolidate_metadata.py` — 4 passed (including new single-shard test) - [x] `uv run pytest -q tests/processing/tokenize/test_tokenize.py::test_tokenize_full_pipeline_integration -m slow` — 1 passed - [x] `./infra/pre-commit.py --fix` — OK - [ ] CoreWeave CI (`cw-ci-test`) passes on this branch 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent dd549bf commit 7e8b8a8

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

lib/levanter/src/levanter/store/cache.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging as pylogging
99
import operator
1010
import os
11+
import shutil
1112
import threading
1213
import time
1314
from concurrent.futures import ThreadPoolExecutor
@@ -429,14 +430,20 @@ def consolidate_shard_caches(
429430

430431
logger.info(f"Consolidating {len(shard_cache_paths)} shard caches into {output_path}")
431432

433+
shard_ledgers = [CacheLedger.load(p, metadata) for p in shard_cache_paths]
434+
435+
if len(shard_cache_paths) == 1:
436+
shard_path = shard_cache_paths[0]
437+
logger.info(f"Single shard cache detected; copying {shard_path} into {output_path} without reopening it.")
438+
_copy_tree_contents(shard_path, output_path)
439+
return _merge_ledgers(output_path, shard_cache_paths, shard_ledgers, metadata)
440+
432441
first_cache = TreeStore.open(exemplar, shard_cache_paths[0], mode="r", cache_metadata=True)
433442
data_offset_tree = jax.tree.map(lambda x: 0, first_cache.tree)
434443

435444
shard_info: list[dict] = []
436445
total_rows = 0
437446

438-
shard_ledgers = [CacheLedger.load(p, metadata) for p in shard_cache_paths]
439-
440447
# Parallel: open each TreeStore to read data_size (dominates wall time on remote storage)
441448
def _get_data_sizes(shard_path):
442449
store = TreeStore.open(exemplar, shard_path, mode="r", cache_metadata=True)
@@ -488,6 +495,30 @@ def _copy_shard(info: dict):
488495
return final_ledger
489496

490497

498+
def _copy_tree_contents(source_path: str, dest_path: str) -> None:
    """Recursively copy every file under ``source_path`` into ``dest_path`` via fsspec.

    Used for single-shard cache consolidation: the shard's files are copied
    verbatim instead of reopening the store with TensorStore (which fails with
    ``Malformed StorageGeneration`` on R2-backed S3).

    Args:
        source_path: URL or local path of the tree to copy (any fsspec-supported scheme).
        dest_path: URL or local path of the destination root; created if missing.

    Returns:
        None. Side effect: the full file tree under ``source_path`` exists under
        ``dest_path`` when this returns.
    """
    # posixpath, not os.path: fsspec filesystem keys are always "/"-separated,
    # even on Windows, so os.path would build backslash-broken remote keys.
    import posixpath

    src_fs, src_root = url_to_fs(source_path)
    dest_fs, dest_root = url_to_fs(dest_path)

    dest_fs.makedirs(dest_root, exist_ok=True)
    entries = src_fs.find(src_root, withdirs=True, detail=True)

    # Track parents we've already created so we don't issue a makedirs round-trip
    # per file — this dominates latency on remote object stores.
    created_parents: set[str] = {dest_root}

    for src_file, info in entries.items():
        if info.get("type") == "directory":
            continue

        rel_path = posixpath.relpath(src_file, src_root)
        dest_file = posixpath.join(dest_root, rel_path)
        dest_parent = posixpath.dirname(dest_file)
        if dest_parent and dest_parent not in created_parents:
            dest_fs.makedirs(dest_parent, exist_ok=True)
            created_parents.add(dest_parent)

        if type(src_fs) is type(dest_fs):
            # Same filesystem implementation: let it copy natively (server-side
            # copy on object stores, avoiding a download/upload round-trip).
            src_fs.copy(src_file, dest_file)
        else:
            # Cross-filesystem: stream bytes through this process.
            with src_fs.open(src_file, "rb") as src_f, dest_fs.open(dest_file, "wb") as dest_f:
                shutil.copyfileobj(src_f, dest_f)
520+
521+
491522
def _merge_ledgers(
492523
output_path: str, shard_cache_paths: list[str], shard_ledgers: list[CacheLedger], metadata: CacheMetadata
493524
) -> CacheLedger:

tests/test_consolidate_metadata.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,22 @@ def test_consolidate_shard_caches_end_to_end():
152152
for i in range(NUM_SHARDS):
153153
row = merged[i * ROWS_PER_SHARD]
154154
assert row["input_ids"][0] == i, f"shard {i} data mismatch"
155+
156+
157+
def test_consolidate_shard_caches_single_shard():
    """Consolidating exactly one shard yields a finished ledger with the shard's data intact."""
    with tempfile.TemporaryDirectory(prefix="levanter-test-consolidate-single-") as tmpdir:
        source = os.path.join(tmpdir, "shard_0")
        fill_value = 7
        _build_shard_cache(
            source,
            EXEMPLAR_FLAT,
            [{"input_ids": np.full((ROW_WIDTH,), fill_value, dtype=np.int32)} for _ in range(ROWS_PER_SHARD)],
        )

        merged_path = os.path.join(tmpdir, "merged")
        ledger = consolidate_shard_caches([source], merged_path, EXEMPLAR_FLAT, copy_max_workers=1)

        shard_name = os.path.basename(source)
        assert ledger.is_finished
        assert ledger.total_num_rows == ROWS_PER_SHARD
        assert ledger.shard_rows == {shard_name: ROWS_PER_SHARD}
        assert ledger.finished_shards == [shard_name]

        merged = TreeStore.open(EXEMPLAR_FLAT, merged_path, mode="r", cache_metadata=True)
        assert len(merged) == ROWS_PER_SHARD
        assert merged[0]["input_ids"][0] == fill_value

0 commit comments

Comments
 (0)