3030
3131from levanter .data .dataset import AsyncDataset
3232from levanter .utils .jax_utils import broadcast_one_to_all
33+ from levanter .utils .thread_utils import blocking_wait
3334
3435from ..data ._preprocessor import BatchProcessor , BatchResult , dict_from_record_batch
3536from ..data .sharded_datasource import ShardedDataSource
@@ -160,6 +161,219 @@ def is_finished(self):
160161 return True
161162
162163
class _VirtualRead:
    """Lazy handle for an asynchronous read.

    Mimics a tensorstore-style future API: ``read()`` returns the handle
    itself, which the caller can then either ``await`` or resolve
    synchronously via ``result()``.
    """

    def __init__(self, read_async):
        # read_async: zero-argument callable returning an awaitable that
        # produces the data.
        self._read_async = read_async

    def read(self):
        # Allows chaining as ``obj[idx].read()`` like the real array stores.
        return self

    def __await__(self):
        coro = self._read_async()
        return coro.__await__()

    def result(self):
        # Synchronous resolution; blocks the calling thread until done.
        return blocking_wait(self._read_async())
177+
class _ShardedArray:
    """Read-only view that concatenates several shard-local data arrays.

    Indexing returns a lazy ``_VirtualRead``; the actual per-shard reads are
    only issued when the handle is awaited (or resolved via ``.result()``).
    """

    def __init__(self, arrays, sizes: list[int]):
        # arrays: per-shard array-likes supporting ``arr[item].read()``
        # sizes: element count of each shard, used to build global boundaries
        self._arrays = arrays
        self._sizes = sizes
        self._boundaries = _cumulative_offsets(sizes)

    def __getitem__(self, item):
        # Defer the read; callers await or block on the returned handle.
        return _VirtualRead(lambda: self._read(item))

    async def _read(self, item):
        if isinstance(item, slice):
            start, stop, step = item.indices(self._boundaries[-1])
            if step != 1:
                if step > 0:
                    values = await self._read(slice(start, stop))
                    return values[::step]
                # Fix: a negative step previously recursed with
                # slice(start, stop) where start > stop, which always read
                # nothing. Read the covered range [stop+1, start+1) forward,
                # then apply the reversed stride.
                values = await self._read(slice(stop + 1, start + 1))
                return values[::-1][::-step]
            pieces = []
            for shard_index, local_slice in _split_slice_by_boundaries(start, stop, self._boundaries):
                pieces.append(await self._arrays[shard_index][local_slice].read())
            return _concatenate_or_empty(pieces)

        # Single-element access: translate the (possibly negative) global
        # index into (shard, local index) and read from that shard.
        index = item
        if index < 0:
            index += self._boundaries[-1]
        if index < 0 or index >= self._boundaries[-1]:
            raise IndexError("Index out of bounds")
        shard_index = bisect.bisect_right(self._boundaries, index) - 1
        local_index = index - self._boundaries[shard_index]
        return await self._arrays[shard_index][local_index].read()
206+
207+
class _ShardedOffsets:
    """Virtual, read-only view over the offsets arrays of several shard stores.

    Presents the per-shard ``JaggedArrayStore.offsets`` arrays as one logical
    offsets array spanning all shards, with each shard's data offsets rebased
    onto a single global data address space.
    """

    def __init__(self, stores: list[JaggedArrayStore]):
        self._stores = stores
        # Total logical rows across every shard.
        self._num_rows = sum(store.num_rows for store in stores)
        # Per-shard data sizes, used to rebase each shard's offsets below.
        self._data_sizes = [store.data_size for store in stores]

    def __getitem__(self, item):
        # Lazy read handle; resolved when awaited or ``.result()`` is called.
        return _VirtualRead(lambda: self._read(item))

    async def _read(self, item):
        # NOTE(review): every access reads and stitches ALL shard offsets —
        # O(total rows) per lookup; fine for occasional reads, costly in loops.
        offsets = await self._full_offsets()
        return offsets[item]

    async def _full_offsets(self):
        # Fetch each shard's offsets[0 : num_rows + 1] concurrently.
        offset_reads = [store.offsets[0 : store.num_rows + 1].read() for store in self._stores]
        per_shard_offsets = await asyncio.gather(*offset_reads)
        # Slot 0 of the stitched array holds the total row count — presumably
        # mirroring JaggedArrayStore's own offsets layout; TODO confirm
        # against that class before relying on it.
        adjusted_offsets = [np.asarray([self._num_rows], dtype=np.int64)]
        data_base = 0
        for offsets, data_size in zip(per_shard_offsets, self._data_sizes):
            offsets = np.asarray(offsets, dtype=np.int64)
            # Clear the shard's slot-0 header value; the [1:] slice below
            # drops it anyway, so this in-place write is defensive only.
            offsets[0] = 0
            # Rebase this shard's offsets into global data coordinates.
            adjusted_offsets.append(offsets[1:] + data_base)
            data_base += data_size
        return np.concatenate(adjusted_offsets)
232+
233+
class _ShardedShapes:
    """Virtual view concatenating the per-shard ``shapes`` arrays.

    Built only when shard 0 tracks shapes; each underlying store is asserted
    to have non-None ``shapes`` at read time.
    """

    def __init__(self, stores: list[JaggedArrayStore]):
        self._stores = stores
        self._sizes = [store.num_rows for store in stores]
        self._boundaries = _cumulative_offsets(self._sizes)

    def __getitem__(self, item):
        # Lazy read handle; resolved when awaited or ``.result()`` is called.
        return _VirtualRead(lambda: self._read(item))

    async def _read(self, item):
        if isinstance(item, slice):
            start, stop, step = item.indices(self._boundaries[-1])
            if step != 1:
                if step > 0:
                    values = await self._read(slice(start, stop))
                    return values[::step]
                # Fix: a negative step previously recursed with
                # slice(start, stop) where start > stop and always came back
                # empty. Read the covered range forward, then apply the
                # reversed stride.
                values = await self._read(slice(stop + 1, start + 1))
                return values[::-1][::-step]
            pieces = []
            for shard_index, local_slice in _split_slice_by_boundaries(start, stop, self._boundaries):
                shapes = self._stores[shard_index].shapes
                assert shapes is not None
                pieces.append(await shapes[local_slice].read())
            return _concatenate_or_empty(pieces)

        # Single-element access: map the global row to (shard, local row).
        index = item
        if index < 0:
            index += self._boundaries[-1]
        if index < 0 or index >= self._boundaries[-1]:
            raise IndexError("Index out of bounds")
        shard_index = bisect.bisect_right(self._boundaries, index) - 1
        local_index = index - self._boundaries[shard_index]
        shapes = self._stores[shard_index].shapes
        assert shapes is not None
        return await shapes[local_index].read()
266+
267+
class ShardedJaggedArrayStore:
    """Virtual JaggedArrayStore backed by multiple shard-local stores.

    Rows are numbered globally in shard order; batched reads are fanned out
    as one batched request per shard.
    """

    def __init__(self, stores: list[JaggedArrayStore]):
        if not stores:
            raise ValueError("ShardedJaggedArrayStore requires at least one store")
        self._stores = stores
        self.item_rank = stores[0].item_rank
        self.offsets = _ShardedOffsets(stores)
        data_arrays = [store.data for store in stores]
        data_sizes = [store.data_size for store in stores]
        self.data = _ShardedArray(data_arrays, data_sizes)
        # Only expose shapes when the underlying stores track them.
        self.shapes = None if stores[0].shapes is None else _ShardedShapes(stores)

    @property
    def num_rows(self):
        # Recomputed on access rather than cached at construction.
        return sum(store.num_rows for store in self._stores)

    async def num_rows_async(self):
        return self.num_rows

    @property
    def data_size(self):
        return sum(store.data_size for store in self._stores)

    async def data_size_async(self):
        return self.data_size

    def __len__(self):
        return self.num_rows

    def __getitem__(self, item):
        if not isinstance(item, slice):
            shard_index, local_index = self._resolve_row(item)
            return self._stores[shard_index][local_index]
        start, stop, step = item.indices(len(self))
        return self.get_batch_sync(list(range(start, stop, step)))

    async def get_batch(self, indices: Sequence[int]) -> Sequence[np.ndarray]:
        """Fetch rows by global index, one concurrent batched read per shard."""
        grouped = _group_indices_by_shard(indices, self._row_boundaries())
        results: list[None | np.ndarray] = [None] * len(indices)

        async def fetch_shard(shard_index: int, items: list[tuple[int, int]]):
            # items: (position_in_output, local_row) pairs for this shard.
            batch = await self._stores[shard_index].get_batch([local for _, local in items])
            for (position, _), value in zip(items, batch):
                results[position] = value

        tasks = [fetch_shard(shard, items) for shard, items in grouped.items()]
        await asyncio.gather(*tasks)
        return results

    def get_batch_sync(self, indices: Sequence[int]) -> Sequence[np.ndarray]:
        """Synchronous counterpart of :meth:`get_batch`."""
        grouped = _group_indices_by_shard(indices, self._row_boundaries())
        results: list[None | np.ndarray] = [None] * len(indices)
        for shard_index, items in grouped.items():
            batch = self._stores[shard_index].get_batch_sync([local for _, local in items])
            for (position, _), value in zip(items, batch):
                results[position] = value
        return results

    def _resolve_row(self, index: int) -> tuple[int, int]:
        """Map a (possibly negative) global row index to (shard, local row)."""
        boundaries = self._row_boundaries()
        total = boundaries[-1]
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError("Index out of bounds")
        shard_index = bisect.bisect_right(boundaries, index) - 1
        return shard_index, index - boundaries[shard_index]

    def _row_boundaries(self):
        # Fresh prefix sums each call; shard row counts may not be static.
        return _cumulative_offsets([store.num_rows for store in self._stores])
339+
340+
class ShardedTreeStore:
    """Virtual TreeStore backed by multiple shard-local TreeStores.

    Each leaf of the tree becomes a :class:`ShardedJaggedArrayStore` stitching
    together the corresponding leaf from every shard.
    """

    def __init__(self, stores: list[TreeStore]):
        if not stores:
            raise ValueError("ShardedTreeStore requires at least one store")
        self.path = stores[0].path
        self.mode = "r"
        self._stores = stores
        # Zip the shard trees leaf-wise into virtual sharded arrays.
        shard_trees = [store.tree for store in stores]
        self.tree = jax.tree.map(lambda *leaves: ShardedJaggedArrayStore(list(leaves)), *shard_trees)

    def __len__(self):
        first_leaf = jax.tree.leaves(self.tree)[0]
        return len(first_leaf)

    async def async_len(self):
        return len(self)

    def __getitem__(self, item):
        if not isinstance(item, slice):
            return jax.tree.map(lambda reader: reader[item], self.tree)
        start, stop, step = item.indices(len(self))
        return self.get_batch_sync(list(range(start, stop, step)))

    async def get_batch(self, indices) -> List[T_co]:
        """Fetch a batch of examples, awaiting all leaf reads concurrently."""
        pending = jax.tree.map(lambda reader: reader.get_batch(indices), self.tree)
        leaves, structure = jax.tree.flatten(pending)
        fetched = await asyncio.gather(*leaves)
        return [
            jax.tree.unflatten(structure, [leaf[i] for leaf in fetched])
            for i in range(len(indices))
        ]

    def get_batch_sync(self, indices) -> List[T_co]:
        """Synchronous counterpart of :meth:`get_batch`."""
        per_leaf = jax.tree.map(lambda reader: reader.get_batch_sync(indices), self.tree)
        out = []
        for i in range(len(indices)):
            out.append(jax.tree.map(lambda _, leaf: leaf[i], self.tree, per_leaf))
        return out
375+
376+
163377class ShardedTreeCache (AsyncDataset [T_co ]):
164378 """Reads across multiple shard caches without requiring a consolidation step.
165379
@@ -181,6 +395,11 @@ def __init__(self, shard_paths: list[str], exemplar: T_co, ledger: "CacheLedger"
181395 rows = ledger .shard_rows .get (shard_name , 0 )
182396 self ._cum_rows .append (self ._cum_rows [- 1 ] + rows )
183397 self ._stores .append (TreeStore .open (exemplar , path , mode = "r" , cache_metadata = False ))
398+ self ._store = ShardedTreeStore (self ._stores )
399+
400+ @property
401+ def store (self ) -> ShardedTreeStore :
402+ return self ._store
184403
185404 def _resolve_index (self , global_idx : int ) -> tuple [int , int ]:
186405 """Return (shard_index, local_row) for a global row index."""
@@ -203,7 +422,10 @@ def __getitem__(self, item):
203422 shard_idx , local_idx = self ._resolve_index (item )
204423 return self ._stores [shard_idx ][local_idx ]
205424
206- async def get_batch (self , indices : Sequence [int ]) -> Sequence :
425+ async def get_batch (self , indices : Sequence [int ] | slice ) -> Sequence :
426+ if isinstance (indices , slice ):
427+ indices = range (indices .start or 0 , indices .stop or len (self ), indices .step or 1 )
428+
207429 # Group indices by shard, preserving original order
208430 shard_groups : dict [int , list [tuple [int , int ]]] = {} # shard_idx -> [(position_in_output, local_idx)]
209431 for pos , global_idx in enumerate (indices ):
@@ -248,6 +470,51 @@ def is_finished(self):
248470 return True
249471
250472
def _cumulative_offsets(sizes: Sequence[int]) -> list[int]:
    """Prefix-sum shard sizes into boundary offsets.

    Returns len(sizes) + 1 entries: entry i is the global index of the first
    element of shard i, and the last entry is the total size.
    """
    running = 0
    boundaries = [0]
    for size in sizes:
        running += size
        boundaries.append(running)
    return boundaries
478+
479+
def _split_slice_by_boundaries(start: int, stop: int, boundaries: Sequence[int]) -> list[tuple[int, slice]]:
    """Break the half-open global range [start, stop) into per-shard slices.

    Returns (shard_index, local_slice) pairs in shard order; empty shards and
    empty intersections contribute nothing.
    """
    if stop <= start:
        return []
    result: list[tuple[int, slice]] = []
    # bisect_right lands us past duplicate boundaries from empty shards.
    shard = bisect.bisect_right(boundaries, start) - 1
    cursor = start
    last = len(boundaries) - 1
    while cursor < stop and shard < last:
        lo = boundaries[shard]
        hi = boundaries[shard + 1]
        end = stop if stop < hi else hi
        if cursor < end:
            result.append((shard, slice(cursor - lo, end - lo)))
        cursor = end
        shard += 1
    return result
494+
495+
def _concatenate_or_empty(pieces: Sequence[np.ndarray]) -> np.ndarray:
    """Concatenate piece arrays, tolerating zero or one piece.

    A single piece avoids the concatenate copy; no pieces yields an empty
    array (default float dtype, matching ``np.asarray([])``).
    """
    if len(pieces) > 1:
        return np.concatenate(pieces)
    if pieces:
        return np.asarray(pieces[0])
    return np.asarray([])
502+
503+
def _group_indices_by_shard(indices: Sequence[int], boundaries: Sequence[int]) -> dict[int, list[tuple[int, int]]]:
    """Bucket global row indices by owning shard.

    Returns shard_index -> [(position_in_output, local_index), ...] with each
    bucket preserving the order the indices were supplied in. Negative
    indices wrap around the total row count; out-of-range indices raise
    IndexError.
    """
    total = boundaries[-1]
    grouped: dict[int, list[tuple[int, int]]] = {}
    for position, raw in enumerate(indices):
        idx = raw + total if raw < 0 else raw
        if not 0 <= idx < total:
            raise IndexError("Index out of bounds")
        shard = bisect.bisect_right(boundaries, idx) - 1
        grouped.setdefault(shard, []).append((position, idx - boundaries[shard]))
    return grouped
516+
517+
251518@dataclass_json
252519@dataclass
253520class CacheLedger :
0 commit comments