Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 48 additions & 60 deletions lib/levanter/src/levanter/store/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,13 +466,8 @@ def _copy_shard(info: dict):
verbose=False,
)

# do metadata serially b/c of write amplification concerns
for info in shard_info:
asyncio.run(
_extend_cache_metadata_with_other(
output_path, info["path"], exemplar, info["data_offset_tree"], info["row_offset"]
)
)
# Single shared transaction to coalesce metadata writes (see #4100, tensorstore#202)
asyncio.run(_consolidate_metadata(output_path, exemplar, shard_info))

final_ledger = _merge_ledgers(output_path, shard_cache_paths, shard_ledgers, metadata)
# as a final step, set the total num rows in the final cache
Expand Down Expand Up @@ -611,64 +606,57 @@ async def _copy_in_batches(dest_array, dest_offset, src_array, src_len, elems_pe
await last_future


async def _extend_cache_metadata_with_other(
dest_path: str, source_path: str, exemplar: dict, data_offset_tree: PyTree[int], row_offset
) -> int:
try:
logger.info(f"Copying metadata from {source_path} to {dest_path}.")
dest = TreeStore.open(exemplar, dest_path, mode="a")
source = TreeStore.open(exemplar, source_path, mode="r", cache_metadata=True)
async def _consolidate_metadata(dest_path: str, exemplar: dict, shard_infos: list[dict]) -> None:
"""Copy metadata (offsets + shapes) from all shards into dest using a single shared transaction.

source_num_rows = await source.async_len()
Replaces the old per-shard loop that committed a transaction per shard, causing
O(num_shards) read-modify-write cycles on the same zarr3 chunks (tensorstore#202).
"""
dest = TreeStore.open(exemplar, dest_path, mode="a")
start = time.monotonic()

async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int):
if source_array.shapes is not None:
source_shapes = source_array.shapes
async with ts.Transaction() as txn:
dest_shapes = dest_array.shapes
assert dest_shapes is not None
out_end = row_offset + source_num_rows
shape_future = dest_shapes.with_transaction(txn)[row_offset:out_end].write(source_shapes)

source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]]
source_offsets = _virtual_offset(source_offsets, data_offset)

delay = 4
while True:
try:
async with ts.Transaction() as txn:
dest_offsets = dest_array.offsets
delay = 4
while True:
write_futures = []
try:
async with ts.Transaction() as txn:
for info in shard_infos:
source = TreeStore.open(exemplar, info["path"], mode="r", cache_metadata=True)
source_num_rows = info["ledger"].total_num_rows
row_offset = info["row_offset"]

for dest_array, source_array, data_offset in zip(
jax.tree.leaves(dest.tree),
jax.tree.leaves(source.tree),
jax.tree.leaves(info["data_offset_tree"]),
):
if source_array.shapes is not None:
assert dest_array.shapes is not None
source_shapes = await source_array.shapes[:source_num_rows].read()
out_end = row_offset + source_num_rows
write_futures.append(
dest_array.shapes.with_transaction(txn)[row_offset:out_end].write(source_shapes)
)

source_offsets = await source_array.offsets[1 : source_num_rows + 1].read()
source_offsets = np.asarray(source_offsets) + data_offset
out_end = 1 + row_offset + source_num_rows
offset_future = dest_offsets.with_transaction(txn)[row_offset + 1 : out_end].write(
source_offsets
write_futures.append(
dest_array.offsets.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets)
)
break
except ValueError as e:
if "Please reduce your request rate." in str(e):
logger.info("Rate limit exceeded. Retrying.")
await asyncio.sleep(delay)
delay *= 2
if delay > 120:
raise
await offset_future
if source_array.shapes is not None:
await shape_future

futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree)

await asyncio.gather(*jax.tree.leaves(futures))
logger.info(f"Finished copying metadata from {source_path} to {dest_path}.")
return source_num_rows
except Exception as e: # noqa: BLE001
logger.exception(f"Failed to copy metadata from {source_path} to {dest_path}: {e}")
raise


def _virtual_offset(base: ts.TensorStore, offset_amount):
    """Return a read-only virtual view of *base* with *offset_amount* added to every element.

    No data is copied or written: each read of the returned store fetches the
    corresponding region of ``base`` on demand and applies the shift in memory.
    """

    async def do_read(domain: ts.IndexDomain, array: np.ndarray, read_params: ts.VirtualChunkedReadParameters):
        # Fill the requested chunk from the underlying store, shifted by the offset.
        array[...] = (await base[domain].read()) + offset_amount

    return ts.virtual_chunked(do_read, dtype=base.dtype, domain=base.domain, shape=base.shape)
await asyncio.gather(*write_futures)
elapsed = time.monotonic() - start
logger.info(f"Metadata consolidation complete: {len(shard_infos)} shards in {elapsed:.1f}s")
break
except ValueError as e:
if "Please reduce your request rate." not in str(e):
raise
logger.info(f"Rate limit exceeded during metadata consolidation. Retrying in {delay}s.")
await asyncio.sleep(delay)
delay *= 2
if delay > 120:
raise


def _sanitize_shard_name(name: str) -> str:
Expand Down
109 changes: 109 additions & 0 deletions tests/test_consolidate_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

# Copyright The Levanter Authors
# SPDX-License-Identifier: Apache-2.0

"""Tests for consolidated metadata copy using a shared ts.Transaction (#4100)."""

import asyncio
import copy
import operator
import os
import tempfile

import jax
import numpy as np

from levanter.store.cache import CacheLedger, _consolidate_metadata, _expose_cache_rows, _extend_cache_with_other_cache
from levanter.store.tree_store import TreeStore

NUM_SHARDS = 8
ROWS_PER_SHARD = 32
ROW_WIDTH = 16

# rank-1 only (no shapes metadata)
EXEMPLAR_FLAT = {"input_ids": np.array([0], dtype=np.int32)}

# multi-field with a rank-2 leaf (triggers shapes metadata)
EXEMPLAR_SHAPED = {
"input_ids": np.array([0], dtype=np.int32),
"spans": np.zeros((0, 2), dtype=np.int32),
}


def _build_and_consolidate(exemplar, make_row) -> TreeStore:
    """Build shards, copy data + metadata, and return the merged store.

    Bug fix: the previous version created the temp directory with a ``with``
    block, so the backing files were deleted *before* callers used the
    returned ``TreeStore`` (e.g. ``test_consolidate_metadata_shaped`` reads
    ``merged[0]`` after return), making reads flaky whenever TensorStore had
    to fetch uncached data from disk. We now keep the ``TemporaryDirectory``
    handle alive by attaching it to the returned store; cleanup happens when
    the store is garbage collected (or explicitly on failure below).

    Args:
        exemplar: pytree exemplar describing the store's leaf dtypes/ranks.
        make_row: callable(shard_index) -> one row dict for that shard.

    Returns:
        The merged, read-mode ``TreeStore`` containing all shards' rows.
    """
    tmpdir_handle = tempfile.TemporaryDirectory(prefix="levanter-test-consolidate-")
    tmpdir = tmpdir_handle.name
    try:
        shard_root = os.path.join(tmpdir, "shards")
        os.makedirs(shard_root)

        data_offset_tree = jax.tree.map(lambda _: 0, exemplar)
        total_rows = 0
        shard_infos = []

        for i in range(NUM_SHARDS):
            shard_path = os.path.join(shard_root, f"shard_{i}")
            store = TreeStore.open(exemplar, shard_path, mode="w", cache_metadata=True)
            store.extend([make_row(i) for _ in range(ROWS_PER_SHARD)])

            shard_infos.append(
                {
                    "path": shard_path,
                    "row_offset": total_rows,
                    # deepcopy: the running offset tree is mutated below via jax.tree.map
                    "data_offset_tree": copy.deepcopy(data_offset_tree),
                    "ledger": CacheLedger(total_num_rows=ROWS_PER_SHARD, shard_rows={}, is_finished=True),
                }
            )
            total_rows += ROWS_PER_SHARD

            # accumulate per-leaf data sizes so the next shard's data lands after this one's
            this_offsets = jax.tree.map(lambda x: x.data_size, store.tree)
            data_offset_tree = jax.tree.map(operator.add, data_offset_tree, this_offsets)

        dest_path = os.path.join(tmpdir, "dest")
        TreeStore.open(exemplar, dest_path, mode="w", cache_metadata=True)

        # copy the raw data per shard, then consolidate all metadata in one transaction
        for info in shard_infos:
            asyncio.run(
                _extend_cache_with_other_cache(
                    dest_path,
                    info["path"],
                    exemplar,
                    info["data_offset_tree"],
                    info["row_offset"],
                )
            )
        asyncio.run(_consolidate_metadata(dest_path, exemplar, shard_infos))
        _expose_cache_rows(dest_path, exemplar, total_rows)

        merged = TreeStore.open(exemplar, dest_path, mode="r", cache_metadata=True)
        assert len(merged) == NUM_SHARDS * ROWS_PER_SHARD

        for i, info in enumerate(shard_infos):
            row = merged[info["row_offset"]]
            assert row["input_ids"][0] == i, f"shard {i} data mismatch"
    except BaseException:
        # deterministic cleanup on failure; on success the guard below owns it
        tmpdir_handle.cleanup()
        raise

    # Keep the backing directory alive for as long as callers hold the store.
    merged._tmpdir_guard = tmpdir_handle
    return merged
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep merged store reads inside TemporaryDirectory scope

_build_and_consolidate returns a live TreeStore handle from inside a with tempfile.TemporaryDirectory(...) block, so the backing directory is deleted before callers use it. In test_consolidate_metadata_shaped, merged[0] is read after return; this can fail or become flaky when TensorStore needs to fetch uncached data from disk. This makes the new test unreliable across environments and cache behavior.

Useful? React with 👍 / 👎.



def test_consolidate_metadata_flat():
    """Round-trip with a single rank-1 field (no shapes metadata)."""

    def _row_for_shard(idx):
        # Fill every row of shard ``idx`` with the shard index itself so the
        # helper can verify each shard's data landed at the right row offset.
        return {"input_ids": np.full((ROW_WIDTH,), idx, dtype=np.int32)}

    _build_and_consolidate(EXEMPLAR_FLAT, _row_for_shard)


def test_consolidate_metadata_shaped():
    """Round-trip with multiple fields including rank-2 (exercises shapes metadata)."""

    def _row_for_shard(idx):
        # The rank-2 "spans" leaf forces per-row shapes metadata to be
        # written and copied alongside the offsets.
        return {
            "input_ids": np.full((ROW_WIDTH,), idx, dtype=np.int32),
            "spans": np.full((3, 2), idx, dtype=np.int32),
        }

    store = _build_and_consolidate(EXEMPLAR_SHAPED, _row_for_shard)
    first_row = store[0]
    assert first_row["spans"].shape == (3, 2)
Loading