Skip to content

Commit f85f7ee

Browse files
committed
[codex] add in-memory virtual tree for sharded caches
1 parent dd81e6a commit f85f7ee

File tree

2 files changed

+291
-1
lines changed

2 files changed

+291
-1
lines changed

lib/levanter/src/levanter/store/cache.py

Lines changed: 268 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from levanter.data.dataset import AsyncDataset
3232
from levanter.utils.jax_utils import broadcast_one_to_all
33+
from levanter.utils.thread_utils import blocking_wait
3334

3435
from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch
3536
from ..data.sharded_datasource import ShardedDataSource
@@ -160,6 +161,219 @@ def is_finished(self):
160161
return True
161162

162163

164+
class _VirtualRead:
    """Lazy read handle that mimics the tensorstore future API.

    Wraps a zero-argument coroutine factory. The value can be obtained
    either by awaiting the handle directly, or by calling ``result()``
    from synchronous code.
    """

    def __init__(self, read_async):
        # Factory, not a coroutine: a fresh coroutine is created on every
        # await / result() so the handle can be consumed more than once.
        self._read_async = read_async

    def read(self):
        # tensorstore-style chaining: ``arr[i].read()`` yields the future itself.
        return self

    def __await__(self):
        coro = self._read_async()
        return coro.__await__()

    def result(self):
        # Blocking path for non-async callers.
        return blocking_wait(self._read_async())
176+
177+
178+
class _ShardedArray:
    """Read-only view presenting several per-shard arrays as one concatenated array.

    Each element of ``arrays`` must support ``arr[index_or_slice].read()``
    returning an awaitable (e.g. ``JaggedArrayStore.data``). ``sizes`` gives
    each shard's element count; shard boundaries are their prefix sums.
    """

    def __init__(self, arrays, sizes: list[int]):
        self._arrays = arrays
        self._sizes = sizes
        # _boundaries[i] is the global offset where shard i starts;
        # _boundaries[-1] is the total element count.
        self._boundaries = _cumulative_offsets(sizes)

    def __getitem__(self, item):
        # Defer the actual read until the caller awaits / .result()s it.
        return _VirtualRead(lambda: self._read(item))

    async def _read(self, item):
        if isinstance(item, slice):
            start, stop, step = item.indices(self._boundaries[-1])
            if step != 1:
                # Fetch the contiguous region covering the requested indices,
                # then subsample (and/or reverse) in memory.
                if step > 0:
                    values = await self._read(slice(start, stop))
                else:
                    # BUGFIX: a negative step yields start > stop, so the old
                    # slice(start, stop) read was always empty. The covered
                    # region is [stop + 1, start + 1); values[::step] then
                    # walks it backwards starting at global index ``start``.
                    values = await self._read(slice(stop + 1, start + 1))
                return values[::step]
            pieces = []
            for shard_index, local_slice in _split_slice_by_boundaries(start, stop, self._boundaries):
                pieces.append(await self._arrays[shard_index][local_slice].read())
            return _concatenate_or_empty(pieces)

        # Scalar index: normalize negatives, bounds-check, then map to
        # (shard, local offset) via binary search over the boundaries.
        index = item
        if index < 0:
            index += self._boundaries[-1]
        if index < 0 or index >= self._boundaries[-1]:
            raise IndexError("Index out of bounds")
        shard_index = bisect.bisect_right(self._boundaries, index) - 1
        local_index = index - self._boundaries[shard_index]
        return await self._arrays[shard_index][local_index].read()
206+
207+
208+
class _ShardedOffsets:
    """Virtual offsets array spanning several ``JaggedArrayStore`` shards.

    Reproduces the single-store offsets layout: slot 0 holds the total row
    count, and slots 1..num_rows hold cumulative data offsets, with each
    shard's offsets rebased onto the combined data size of the shards
    before it.
    """

    def __init__(self, stores: list[JaggedArrayStore]):
        self._stores = stores
        # Captured at construction time; this is a read-only view.
        self._num_rows = sum(store.num_rows for store in stores)
        self._data_sizes = [store.data_size for store in stores]

    def __getitem__(self, item):
        return _VirtualRead(lambda: self._read(item))

    async def _read(self, item):
        # Simple but potentially expensive: materialize the full offsets
        # vector and index into it. Acceptable for read-only cache access.
        offsets = await self._full_offsets()
        return offsets[item]

    async def _full_offsets(self):
        # Read every shard's offsets concurrently.
        offset_reads = [store.offsets[0 : store.num_rows + 1].read() for store in self._stores]
        per_shard_offsets = await asyncio.gather(*offset_reads)
        # Slot 0 mirrors the JaggedArrayStore convention of storing num_rows.
        adjusted_offsets = [np.asarray([self._num_rows], dtype=np.int64)]
        data_base = 0
        for offsets, data_size in zip(per_shard_offsets, self._data_sizes):
            offsets = np.asarray(offsets, dtype=np.int64)
            # offsets[0] holds the shard's own row count rather than a data
            # offset, so it is skipped entirely. (The previous in-place
            # ``offsets[0] = 0`` was a dead store and could mutate the
            # buffer returned by read() when asarray did not copy.)
            adjusted_offsets.append(offsets[1:] + data_base)
            data_base += data_size
        return np.concatenate(adjusted_offsets)
232+
233+
234+
class _ShardedShapes:
    """Virtual ``shapes`` array spanning several ``JaggedArrayStore`` shards.

    Same indexing semantics as ``_ShardedArray``, but rows are counted per
    shard (``num_rows``) and reads go through each store's ``shapes`` array.
    """

    def __init__(self, stores: list[JaggedArrayStore]):
        self._stores = stores
        self._sizes = [store.num_rows for store in stores]
        # _boundaries[i] is the global row where shard i starts.
        self._boundaries = _cumulative_offsets(self._sizes)

    def __getitem__(self, item):
        # Defer the read until awaited / .result()ed.
        return _VirtualRead(lambda: self._read(item))

    async def _read(self, item):
        if isinstance(item, slice):
            start, stop, step = item.indices(self._boundaries[-1])
            if step != 1:
                # Fetch the contiguous covered region, then stride in memory.
                if step > 0:
                    values = await self._read(slice(start, stop))
                else:
                    # BUGFIX: with a negative step, start > stop, so the old
                    # slice(start, stop) read was always empty. Read the
                    # covered region [stop + 1, start + 1); values[::step]
                    # then walks it backwards from global index ``start``.
                    values = await self._read(slice(stop + 1, start + 1))
                return values[::step]
            pieces = []
            for shard_index, local_slice in _split_slice_by_boundaries(start, stop, self._boundaries):
                shapes = self._stores[shard_index].shapes
                assert shapes is not None
                pieces.append(await shapes[local_slice].read())
            return _concatenate_or_empty(pieces)

        # Scalar index: normalize, bounds-check, map to (shard, local row).
        index = item
        if index < 0:
            index += self._boundaries[-1]
        if index < 0 or index >= self._boundaries[-1]:
            raise IndexError("Index out of bounds")
        shard_index = bisect.bisect_right(self._boundaries, index) - 1
        local_index = index - self._boundaries[shard_index]
        shapes = self._stores[shard_index].shapes
        assert shapes is not None
        return await shapes[local_index].read()
266+
267+
268+
class ShardedJaggedArrayStore:
    """Virtual JaggedArrayStore backed by multiple shard-local stores."""

    def __init__(self, stores: list[JaggedArrayStore]):
        if not stores:
            raise ValueError("ShardedJaggedArrayStore requires at least one store")
        self._stores = stores
        # The first shard is taken as representative for rank; offsets/data/
        # shapes are exposed through the virtual sharded views.
        self.item_rank = stores[0].item_rank
        self.offsets = _ShardedOffsets(stores)
        self.data = _ShardedArray([s.data for s in stores], [s.data_size for s in stores])
        self.shapes = _ShardedShapes(stores) if stores[0].shapes is not None else None

    @property
    def num_rows(self):
        # Recomputed from the shards on every access rather than cached.
        total = 0
        for store in self._stores:
            total += store.num_rows
        return total

    async def num_rows_async(self):
        return self.num_rows

    @property
    def data_size(self):
        return sum(s.data_size for s in self._stores)

    async def data_size_async(self):
        return self.data_size

    def __len__(self):
        return self.num_rows

    def __getitem__(self, item):
        if not isinstance(item, slice):
            shard, local = self._resolve_row(item)
            return self._stores[shard][local]
        lo, hi, stride = item.indices(len(self))
        return self.get_batch_sync(list(range(lo, hi, stride)))

    async def get_batch(self, indices: Sequence[int]) -> Sequence[np.ndarray]:
        """Fetch rows by global index, concurrently across shards."""
        groups = _group_indices_by_shard(indices, self._row_boundaries())
        out: list[None | np.ndarray] = [None] * len(indices)

        async def pull(shard: int, entries: list[tuple[int, int]]):
            # entries are (position_in_output, local_index) pairs.
            rows = await self._stores[shard].get_batch([local for _, local in entries])
            for (pos, _), row in zip(entries, rows):
                out[pos] = row

        await asyncio.gather(*(pull(shard, entries) for shard, entries in groups.items()))
        return out

    def get_batch_sync(self, indices: Sequence[int]) -> Sequence[np.ndarray]:
        """Synchronous counterpart of :meth:`get_batch`."""
        groups = _group_indices_by_shard(indices, self._row_boundaries())
        out: list[None | np.ndarray] = [None] * len(indices)
        for shard, entries in groups.items():
            rows = self._stores[shard].get_batch_sync([local for _, local in entries])
            for (pos, _), row in zip(entries, rows):
                out[pos] = row
        return out

    def _resolve_row(self, index: int) -> tuple[int, int]:
        # Map a (possibly negative) global row index to (shard, local row).
        bounds = self._row_boundaries()
        total = bounds[-1]
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError("Index out of bounds")
        shard = bisect.bisect_right(bounds, index) - 1
        return shard, index - bounds[shard]

    def _row_boundaries(self):
        # Prefix sums of per-shard row counts; recomputed on demand.
        return _cumulative_offsets([store.num_rows for store in self._stores])
339+
340+
341+
class ShardedTreeStore:
    """Virtual TreeStore backed by multiple shard-local TreeStores."""

    def __init__(self, stores: list[TreeStore]):
        if not stores:
            raise ValueError("ShardedTreeStore requires at least one store")
        # Path of the first shard is kept as a representative; the view is
        # strictly read-only.
        self.path = stores[0].path
        self.mode = "r"
        self._stores = stores
        # Zip the per-shard trees leaf-by-leaf into virtual sharded leaves.
        per_shard_trees = [store.tree for store in stores]
        self.tree = jax.tree.map(lambda *leaves: ShardedJaggedArrayStore(list(leaves)), *per_shard_trees)

    def __len__(self):
        # Length is taken from the first leaf of the tree.
        first_leaf = jax.tree.leaves(self.tree)[0]
        return len(first_leaf)

    async def async_len(self):
        return len(self)

    def __getitem__(self, item):
        if not isinstance(item, slice):
            return jax.tree.map(lambda reader: reader[item], self.tree)
        lo, hi, stride = item.indices(len(self))
        return self.get_batch_sync(list(range(lo, hi, stride)))

    async def get_batch(self, indices) -> List[T_co]:
        """Fetch a batch of examples, gathering all leaves concurrently."""
        per_leaf_futures = jax.tree.map(lambda reader: reader.get_batch(indices), self.tree)
        flat_futures, treedef = jax.tree.flatten(per_leaf_futures)
        flat_batches = await asyncio.gather(*flat_futures)
        # Transpose leaf-major batches back into one tree per example.
        return [jax.tree.unflatten(treedef, [batch[i] for batch in flat_batches]) for i in range(len(indices))]

    def get_batch_sync(self, indices) -> List[T_co]:
        """Synchronous counterpart of :meth:`get_batch`."""
        per_leaf_batches = jax.tree.map(lambda reader: reader.get_batch_sync(indices), self.tree)
        return [jax.tree.map(lambda _, batch: batch[i], self.tree, per_leaf_batches) for i in range(len(indices))]
375+
376+
163377
class ShardedTreeCache(AsyncDataset[T_co]):
164378
"""Reads across multiple shard caches without requiring a consolidation step.
165379
@@ -181,6 +395,11 @@ def __init__(self, shard_paths: list[str], exemplar: T_co, ledger: "CacheLedger"
181395
rows = ledger.shard_rows.get(shard_name, 0)
182396
self._cum_rows.append(self._cum_rows[-1] + rows)
183397
self._stores.append(TreeStore.open(exemplar, path, mode="r", cache_metadata=False))
398+
self._store = ShardedTreeStore(self._stores)
399+
400+
@property
401+
def store(self) -> ShardedTreeStore:
402+
return self._store
184403

185404
def _resolve_index(self, global_idx: int) -> tuple[int, int]:
186405
"""Return (shard_index, local_row) for a global row index."""
@@ -203,7 +422,10 @@ def __getitem__(self, item):
203422
shard_idx, local_idx = self._resolve_index(item)
204423
return self._stores[shard_idx][local_idx]
205424

206-
async def get_batch(self, indices: Sequence[int]) -> Sequence:
425+
async def get_batch(self, indices: Sequence[int] | slice) -> Sequence:
426+
if isinstance(indices, slice):
427+
indices = range(indices.start or 0, indices.stop or len(self), indices.step or 1)
428+
207429
# Group indices by shard, preserving original order
208430
shard_groups: dict[int, list[tuple[int, int]]] = {} # shard_idx -> [(position_in_output, local_idx)]
209431
for pos, global_idx in enumerate(indices):
@@ -248,6 +470,51 @@ def is_finished(self):
248470
return True
249471

250472

473+
def _cumulative_offsets(sizes: Sequence[int]) -> list[int]:
    """Return prefix sums of *sizes*, starting at 0 (length ``len(sizes) + 1``)."""
    total = 0
    result = [total]
    for size in sizes:
        total += size
        result.append(total)
    return result
478+
479+
480+
def _split_slice_by_boundaries(start: int, stop: int, boundaries: Sequence[int]) -> list[tuple[int, slice]]:
    """Break the half-open global range ``[start, stop)`` into per-shard slices.

    *boundaries* are cumulative shard offsets (as produced by
    ``_cumulative_offsets``). Returns ``(shard_index, local_slice)`` pairs in
    ascending shard order; shards with no overlap (e.g. empty shards) are
    skipped. Assumes ``start`` and ``stop`` are already normalized into range.
    """
    if start >= stop:
        return []
    result: list[tuple[int, slice]] = []
    cursor = start
    shard = bisect.bisect_right(boundaries, cursor) - 1
    last_shard = len(boundaries) - 1
    while shard < last_shard and cursor < stop:
        lo = boundaries[shard]
        hi = boundaries[shard + 1]
        piece_end = hi if hi < stop else stop
        if cursor < piece_end:
            result.append((shard, slice(cursor - lo, piece_end - lo)))
        cursor = piece_end
        shard += 1
    return result
494+
495+
496+
def _concatenate_or_empty(pieces: Sequence[np.ndarray]) -> np.ndarray:
    """Concatenate *pieces* into one array.

    An empty input yields an empty (default-dtype) array; a single piece is
    passed through ``np.asarray`` without an extra copy.
    """
    if len(pieces) == 0:
        return np.asarray([])
    if len(pieces) > 1:
        return np.concatenate(pieces)
    return np.asarray(pieces[0])
502+
503+
504+
def _group_indices_by_shard(indices: Sequence[int], boundaries: Sequence[int]) -> dict[int, list[tuple[int, int]]]:
    """Bucket global row *indices* by their owning shard.

    Returns ``{shard_index: [(output_position, local_index), ...]}`` so a
    caller can fetch per shard and scatter the rows back into the original
    order. Negative indices are normalized against the total row count;
    out-of-range indices raise ``IndexError``.
    """
    total = boundaries[-1]
    groups: dict[int, list[tuple[int, int]]] = {}
    for position, raw in enumerate(indices):
        idx = raw + total if raw < 0 else raw
        if not 0 <= idx < total:
            raise IndexError("Index out of bounds")
        shard = bisect.bisect_right(boundaries, idx) - 1
        bucket = groups.get(shard)
        if bucket is None:
            bucket = groups[shard] = []
        bucket.append((position, idx - boundaries[shard]))
    return groups
516+
517+
251518
@dataclass_json
252519
@dataclass
253520
class CacheLedger:

lib/levanter/tests/test_new_cache.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from zephyr.execution import ZephyrWorkerError
1111

1212
from levanter.data import BatchProcessor, ShardedDataSource, batched
13+
from levanter.data.packing import GreedyPrepackedDataset
1314
from levanter.data.sharded_datasource import TextUrlDataSource
1415
from levanter.store.cache import (
1516
CacheLedger,
@@ -334,6 +335,28 @@ def test_sharded_tree_cache_reads_across_shards():
334335
cache = ShardedTreeCache(shard_paths, exemplar, ledger)
335336

336337
assert len(cache) == 40
338+
assert cache.store.tree["data"].num_rows == 40
339+
assert cache.store.tree["data"].data_size == 400
340+
np.testing.assert_array_equal(
341+
cache.store.tree["data"].offsets[0:5].read().result(),
342+
np.asarray([40, 10, 20, 30, 40]),
343+
)
344+
np.testing.assert_array_equal(
345+
cache.store.tree["data"].data[95:105].read().result(),
346+
np.asarray([9, 9, 9, 9, 9, 10, 10, 10, 10, 10]),
347+
)
348+
packed = GreedyPrepackedDataset(
349+
cache.store.tree,
350+
max_length=25,
351+
max_segments_per_example=3,
352+
slice_strategy="raise",
353+
)
354+
packed_batch = packed.as_sync_dataset().get_batch([0])
355+
assert packed_batch[0][0]["data"].shape == (25,)
356+
np.testing.assert_array_equal(
357+
packed_batch[0][0]["data"],
358+
np.asarray([0] * 10 + [1] * 10 + [0] * 5),
359+
)
337360

338361
# Sequential read
339362
for i in range(40):

0 commit comments

Comments
 (0)