From 3da360eb3efb9e297afc3d8ac6bdcc5f81440e28 Mon Sep 17 00:00:00 2001
From: yoblin <268258002+yoblin@users.noreply.github.com>
Date: Tue, 24 Mar 2026 23:12:04 +0000
Subject: [PATCH] Use shared ts.Transaction for metadata consolidation (#4100)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replace per-shard asyncio.run + per-shard transactions in
consolidate_shard_caches with a single shared ts.Transaction that
coalesces all metadata writes. Delete the now-unused per-shard
_extend_cache_metadata_with_other.

- O(num_shards) read-modify-write cycles → O(num_write_chunks)
- Use info["ledger"].total_num_rows instead of redundant async_len() reads
- Fix unsliced shapes write: shapes[:source_num_rows] vs bare shapes

Closes #4100

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 lib/levanter/src/levanter/store/cache.py | 108 ++++++++++------------
 tests/test_consolidate_metadata.py       | 109 +++++++++++++++++++++++
 2 files changed, 157 insertions(+), 60 deletions(-)
 create mode 100644 tests/test_consolidate_metadata.py

diff --git a/lib/levanter/src/levanter/store/cache.py b/lib/levanter/src/levanter/store/cache.py
index 2d1ae7ff72..bd32f68a0a 100644
--- a/lib/levanter/src/levanter/store/cache.py
+++ b/lib/levanter/src/levanter/store/cache.py
@@ -466,13 +466,8 @@ def _copy_shard(info: dict):
             verbose=False,
         )
 
-    # do metadata serially b/c of write amplification concerns
-    for info in shard_info:
-        asyncio.run(
-            _extend_cache_metadata_with_other(
-                output_path, info["path"], exemplar, info["data_offset_tree"], info["row_offset"]
-            )
-        )
+    # Single shared transaction to coalesce metadata writes (see #4100, tensorstore#202)
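+    # Every shard's offsets/shapes land in the same destination arrays, so a
+    # transaction per shard re-read and re-wrote the same chunks once per shard;
+    # one shared transaction touches each chunk once.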
+ """ + dest = TreeStore.open(exemplar, dest_path, mode="a") + start = time.monotonic() - async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int): - if source_array.shapes is not None: - source_shapes = source_array.shapes - async with ts.Transaction() as txn: - dest_shapes = dest_array.shapes - assert dest_shapes is not None - out_end = row_offset + source_num_rows - shape_future = dest_shapes.with_transaction(txn)[row_offset:out_end].write(source_shapes) - - source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]] - source_offsets = _virtual_offset(source_offsets, data_offset) - - delay = 4 - while True: - try: - async with ts.Transaction() as txn: - dest_offsets = dest_array.offsets + delay = 4 + while True: + write_futures = [] + try: + async with ts.Transaction() as txn: + for info in shard_infos: + source = TreeStore.open(exemplar, info["path"], mode="r", cache_metadata=True) + source_num_rows = info["ledger"].total_num_rows + row_offset = info["row_offset"] + + for dest_array, source_array, data_offset in zip( + jax.tree.leaves(dest.tree), + jax.tree.leaves(source.tree), + jax.tree.leaves(info["data_offset_tree"]), + ): + if source_array.shapes is not None: + assert dest_array.shapes is not None + source_shapes = await source_array.shapes[:source_num_rows].read() + out_end = row_offset + source_num_rows + write_futures.append( + dest_array.shapes.with_transaction(txn)[row_offset:out_end].write(source_shapes) + ) + + source_offsets = await source_array.offsets[1 : source_num_rows + 1].read() + source_offsets = np.asarray(source_offsets) + data_offset out_end = 1 + row_offset + source_num_rows - offset_future = dest_offsets.with_transaction(txn)[row_offset + 1 : out_end].write( - source_offsets + write_futures.append( + dest_array.offsets.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets) ) - break - except ValueError as e: - if "Please reduce your request rate." in str(e): - logger.info("Rate limit exceeded. Retrying.") - await asyncio.sleep(delay) - delay *= 2 - if delay > 120: - raise - await offset_future - if source_array.shapes is not None: - await shape_future - - futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) - await asyncio.gather(*jax.tree.leaves(futures)) - logger.info(f"Finished copying metadata from {source_path} to {dest_path}.") - return source_num_rows - except Exception as e: # noqa: BLE001 - logger.exception(f"Failed to copy metadata from {source_path} to {dest_path}: {e}") - raise - - -def _virtual_offset(base: ts.TensorStore, offset_amount): - async def do_read(domain: ts.IndexDomain, array: np.ndarray, read_params: ts.VirtualChunkedReadParameters): - array[...] = (await base[domain].read()) + offset_amount - - return ts.virtual_chunked(do_read, dtype=base.dtype, domain=base.domain, shape=base.shape) + await asyncio.gather(*write_futures) + elapsed = time.monotonic() - start + logger.info(f"Metadata consolidation complete: {len(shard_infos)} shards in {elapsed:.1f}s") + break + except ValueError as e: + if "Please reduce your request rate." not in str(e): + raise + logger.info(f"Rate limit exceeded during metadata consolidation. 
+            await asyncio.gather(*write_futures)
+            elapsed = time.monotonic() - start
+            logger.info(f"Metadata consolidation complete: {len(shard_infos)} shards in {elapsed:.1f}s")
+            break
+        except ValueError as e:
+            if "Please reduce your request rate." not in str(e):
+                raise
+            logger.info(f"Rate limit exceeded during metadata consolidation. Retrying in {delay}s.")
+            await asyncio.sleep(delay)
+            delay *= 2
+            if delay > 120:
+                raise
 
 
 def _sanitize_shard_name(name: str) -> str:
diff --git a/tests/test_consolidate_metadata.py b/tests/test_consolidate_metadata.py
new file mode 100644
index 0000000000..b302a74727
--- /dev/null
+++ b/tests/test_consolidate_metadata.py
@@ -0,0 +1,109 @@
+# Copyright The Marin Authors
+# SPDX-License-Identifier: Apache-2.0
+
+# Copyright The Levanter Authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for consolidated metadata copy using a shared ts.Transaction (#4100)."""
+
+import asyncio
+import copy
+import operator
+import os
+import tempfile
+
+import jax
+import numpy as np
+
+from levanter.store.cache import CacheLedger, _consolidate_metadata, _expose_cache_rows, _extend_cache_with_other_cache
+from levanter.store.tree_store import TreeStore
+
+NUM_SHARDS = 8
+ROWS_PER_SHARD = 32
+ROW_WIDTH = 16
+
+# rank-1 only (no shapes metadata)
+EXEMPLAR_FLAT = {"input_ids": np.array([0], dtype=np.int32)}
+
+# multi-field with a rank-2 leaf (triggers shapes metadata)
+EXEMPLAR_SHAPED = {
+    "input_ids": np.array([0], dtype=np.int32),
+    "spans": np.zeros((0, 2), dtype=np.int32),
+}
+
+
+def _build_and_consolidate(exemplar, make_row) -> TreeStore:
+    """Build shards, copy data + metadata, return the merged store."""
+    with tempfile.TemporaryDirectory(prefix="levanter-test-consolidate-") as tmpdir:
+        shard_root = os.path.join(tmpdir, "shards")
+        os.makedirs(shard_root)
+
+        data_offset_tree = jax.tree.map(lambda _: 0, exemplar)
+        total_rows = 0
+        shard_infos = []
+
+        for i in range(NUM_SHARDS):
+            shard_path = os.path.join(shard_root, f"shard_{i}")
+            store = TreeStore.open(exemplar, shard_path, mode="w", cache_metadata=True)
+            store.extend([make_row(i) for _ in range(ROWS_PER_SHARD)])
+
+            shard_infos.append(
+                {
+                    "path": shard_path,
+                    "row_offset": total_rows,
+                    "data_offset_tree": copy.deepcopy(data_offset_tree),
+                    "ledger": CacheLedger(total_num_rows=ROWS_PER_SHARD, shard_rows={}, is_finished=True),
+                }
+            )
+            total_rows += ROWS_PER_SHARD
+
+            this_offsets = jax.tree.map(lambda x: x.data_size, store.tree)
+            data_offset_tree = jax.tree.map(operator.add, data_offset_tree, this_offsets)
+
+        dest_path = os.path.join(tmpdir, "dest")
+        TreeStore.open(exemplar, dest_path, mode="w", cache_metadata=True)
+
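+        # Mirror consolidate_shard_caches: copy each shard's raw rows first,
+        # then coalesce every shard's offsets/shapes metadata in one pass.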
+        for info in shard_infos:
+            asyncio.run(
+                _extend_cache_with_other_cache(
+                    dest_path,
+                    info["path"],
+                    exemplar,
+                    info["data_offset_tree"],
+                    info["row_offset"],
+                )
+            )
+        asyncio.run(_consolidate_metadata(dest_path, exemplar, shard_infos))
+        _expose_cache_rows(dest_path, exemplar, total_rows)
+
+        merged = TreeStore.open(exemplar, dest_path, mode="r", cache_metadata=True)
+        assert len(merged) == NUM_SHARDS * ROWS_PER_SHARD
+
+        for i, info in enumerate(shard_infos):
+            row = merged[info["row_offset"]]
+            assert row["input_ids"][0] == i, f"shard {i} data mismatch"
+
+        return merged
+
+
+def test_consolidate_metadata_flat():
+    """Round-trip with a single rank-1 field (no shapes metadata)."""
+
+    def make_row(shard_index):
+        return {"input_ids": np.full((ROW_WIDTH,), shard_index, dtype=np.int32)}
+
+    _build_and_consolidate(EXEMPLAR_FLAT, make_row)
+
+
+def test_consolidate_metadata_shaped():
+    """Round-trip with multiple fields including rank-2 (exercises shapes metadata)."""
+
+    def make_row(shard_index):
+        return {
+            "input_ids": np.full((ROW_WIDTH,), shard_index, dtype=np.int32),
+            "spans": np.full((3, 2), shard_index, dtype=np.int32),
+        }
+
+    merged = _build_and_consolidate(EXEMPLAR_SHAPED, make_row)
+    row = merged[0]
+    assert row["spans"].shape == (3, 2)