Skip to content

Commit b95eddb

Browse files
nevillelyh and claude committed
Avoid per-item numpy conversion in JaggedArrayStore write path
TreeStore.extend and extend_with_batch were converting each item to a numpy array individually before passing to JaggedArrayStore.extend, which then concatenated them. For a batch of 16K tokenized sequences this means 16K np.asarray calls + one np.concatenate. Add PreparedBatch.from_sequences() that pre-allocates a single flat array from the cumulative lengths and copies each sequence directly into the right slice. JaggedArrayStore.extend now detects Python sequences (lists) and uses this fast path automatically. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 493c9bb commit b95eddb

4 files changed

Lines changed: 64 additions & 56 deletions

File tree

lib/levanter/src/levanter/store/jagged_array.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,30 @@ def num_rows(self):
7575
return len(self.offsets)
7676

7777
@staticmethod
78-
def from_batch(items: Sequence[np.ndarray], item_rank: Optional[int] = None) -> "PreparedBatch":
78+
def from_batch(items: Sequence, item_rank: Optional[int] = None) -> "PreparedBatch":
79+
if items and not hasattr(items[0], "ndim"):
80+
if (item_rank or 1) == 1:
81+
return PreparedBatch._from_sequences(items)
82+
items = [np.asarray(x) for x in items]
7983
data, offsets, shapes = _prepare_batch(items, item_rank)
8084
return PreparedBatch(data, offsets, shapes)
8185

86+
@staticmethod
87+
def _from_sequences(items: Sequence[Sequence]) -> "PreparedBatch":
88+
"""Build from Python sequences without per-item numpy conversion.
89+
Pre-allocates a single flat array and copies each sequence into it."""
90+
lengths = np.array([len(item) for item in items], dtype=np.int64)
91+
offsets = np.cumsum(lengths)
92+
total = int(offsets[-1]) if len(offsets) else 0
93+
dtype = np.result_type(items[0][0]) if items and len(items[0]) > 0 else np.int64
94+
data = np.empty(total, dtype=dtype)
95+
pos = 0
96+
for item, length in zip(items, lengths):
97+
end = pos + int(length)
98+
data[pos:end] = item
99+
pos = end
100+
return PreparedBatch(data, offsets, None)
101+
82102
@staticmethod
83103
def concat(batches: Sequence["PreparedBatch"]) -> "PreparedBatch":
84104
data = np.concatenate([batch.data for batch in batches])
@@ -205,10 +225,10 @@ async def data_size_async(self):
205225
self._cached_data_size = result
206226
return result
207227

208-
async def append_async(self, data: np.ndarray):
228+
async def append_async(self, data: Sequence):
209229
await self.extend_async([data])
210230

211-
def append(self, data: np.ndarray):
231+
def append(self, data: Sequence):
212232
self.extend([data])
213233

214234
async def trim_to_size_async(self, size: int):
@@ -282,7 +302,7 @@ def trim_to_size(self, size: int):
282302
self._cached_num_rows = size
283303
self._cached_data_size = new_max
284304

285-
async def extend_async(self, arrays: Sequence[np.ndarray] | PreparedBatch):
305+
async def extend_async(self, arrays: Sequence[Sequence] | PreparedBatch):
286306
if isinstance(arrays, PreparedBatch):
287307
prepared = arrays
288308
else:
@@ -313,7 +333,7 @@ async def extend_async(self, arrays: Sequence[np.ndarray] | PreparedBatch):
313333
self._cached_num_rows = num_rows + num_added
314334
self._cached_data_size = current_data_size + len(data)
315335

316-
def extend(self, arrays: Sequence[np.ndarray] | PreparedBatch):
336+
def extend(self, arrays: Sequence[Sequence] | PreparedBatch):
317337
if isinstance(arrays, PreparedBatch):
318338
prepared = arrays
319339
else:

lib/levanter/src/levanter/store/tree_store.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from haliax.jax_utils import is_jax_array_like
1313
from jaxtyping import PyTree
1414

15-
from .jagged_array import JaggedArrayStore, PreparedBatch
15+
from .jagged_array import JaggedArrayStore
1616

1717
T = TypeVar("T", bound=PyTree)
1818

@@ -50,10 +50,6 @@ def __init__(self, tree, path: str, mode: str):
5050
self.mode = mode
5151
self.tree = tree
5252

53-
@property
54-
def batch_preparer(self):
55-
return TreeBatchPreparer(jtu.tree_map(lambda writer: 9, self.tree, is_leaf=heuristic_is_leaf))
56-
5753
@staticmethod
5854
def open(exemplar: T, path: str, *, mode="a", cache_metadata: bool = False) -> "TreeStore":
5955
"""
@@ -70,7 +66,7 @@ def extend(self, batch: Sequence[T]):
7066
Append a batch of data to the store.
7167
"""
7268
jtu.tree_map(
73-
lambda writer, *xs: writer.extend([np.asarray(x) for x in xs]),
69+
lambda writer, *xs: writer.extend(xs),
7470
self.tree,
7571
*batch,
7672
is_leaf=heuristic_is_leaf,
@@ -84,7 +80,7 @@ def extend_with_batch(self, batch: T):
8480
For instance, HF's BatchEncoding is a dict of lists of numpy arrays.
8581
"""
8682
jtu.tree_map(
87-
lambda writer, xs: writer.extend(xs if isinstance(xs, PreparedBatch) else [np.asarray(x) for x in xs]),
83+
lambda writer, xs: writer.extend(xs),
8884
self.tree,
8985
batch,
9086
is_leaf=heuristic_is_leaf_batched,
@@ -98,9 +94,7 @@ async def extend_with_batch_async(self, batch: T):
9894
For instance, HF's BatchEncoding is a dict of lists of numpy arrays.
9995
"""
10096
futures = jtu.tree_map(
101-
lambda writer, xs: writer.extend_async(
102-
xs if isinstance(xs, PreparedBatch) else [np.asarray(x) for x in xs]
103-
),
97+
lambda writer, xs: writer.extend_async(xs),
10498
self.tree,
10599
batch,
106100
is_leaf=heuristic_is_leaf_batched,
@@ -205,16 +199,3 @@ def _render_path_elem(x):
205199
return f"{i}"
206200
case _:
207201
return str(x)
208-
209-
210-
class TreeBatchPreparer(Generic[T]):
211-
def __init__(self, exemplar: T):
212-
self.exemplar = exemplar
213-
214-
def __call__(self, batch: List[T]) -> PyTree:
215-
return jtu.tree_map(
216-
lambda _, *xs: PreparedBatch.from_batch([np.asarray(x) for x in xs]),
217-
self.exemplar,
218-
*batch,
219-
is_leaf=heuristic_is_leaf,
220-
)

lib/levanter/tests/test_jagged_array.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,5 +396,40 @@ async def test_get_batch_empty():
396396
assert batch == []
397397

398398

399+
def test_extend_with_python_lists():
400+
"""Extending a JaggedArrayStore with Python lists should use the fast path
401+
(PreparedBatch.from_sequences) and produce identical results to numpy arrays."""
402+
with tempfile.TemporaryDirectory() as tmpdir:
403+
builder = JaggedArrayStore.open(tmpdir, item_rank=1, dtype=jnp.int64)
404+
405+
lists = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
406+
builder.extend(lists)
407+
408+
assert len(builder) == 3
409+
np.testing.assert_array_equal(builder[0], np.array([1, 2, 3]))
410+
np.testing.assert_array_equal(builder[1], np.array([4, 5]))
411+
np.testing.assert_array_equal(builder[2], np.array([6, 7, 8, 9]))
412+
413+
# Extend again to verify offsets accumulate correctly
414+
builder.extend([[10, 11]])
415+
assert len(builder) == 4
416+
np.testing.assert_array_equal(builder[3], np.array([10, 11]))
417+
418+
419+
def test_from_batch_with_python_lists_matches_numpy():
420+
"""PreparedBatch.from_batch with Python lists should produce the same result as with numpy arrays."""
421+
lists = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
422+
arrays = [np.array(lst, dtype=np.int64) for lst in lists]
423+
424+
from_lists = PreparedBatch.from_batch(lists)
425+
from_arrays = PreparedBatch.from_batch(arrays)
426+
427+
# dtype may differ (int64 inferred vs int32 explicit) but values must match
428+
np.testing.assert_array_equal(from_lists.data, from_arrays.data)
429+
np.testing.assert_array_equal(from_lists.offsets, from_arrays.offsets)
430+
assert from_lists.shapes is None
431+
assert from_arrays.shapes is None
432+
433+
399434
if __name__ == "__main__":
400435
pytest.main()

lib/levanter/tests/test_tree_store.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -257,34 +257,6 @@ def test_reading_from_written():
257257
pytest.fail("Unexpected index")
258258

259259

260-
def test_using_prepared_batches():
261-
with tempfile.TemporaryDirectory() as tmpdir:
262-
exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)}
263-
builder = TreeStore.open(exemplar, tmpdir, mode="w")
264-
preparer = builder.batch_preparer
265-
266-
batch = [
267-
{"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])},
268-
{"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])},
269-
]
270-
batch = preparer(batch)
271-
builder.extend_with_batch(batch)
272-
273-
del builder
274-
275-
builder2 = TreeStore.open(exemplar, tmpdir, mode="r")
276-
277-
for i, result in enumerate(builder2):
278-
if i == 0:
279-
assert np.all(result["a"] == np.array([1.0, 2.0]))
280-
assert np.all(result["b"] == np.array([3.0, 4.0]))
281-
elif i == 1:
282-
assert np.all(result["a"] == np.array([5.0, 6.0]))
283-
assert np.all(result["b"] == np.array([7.0, 8.0]))
284-
else:
285-
pytest.fail("Unexpected index")
286-
287-
288260
def test_resolve_changed_cache_size():
289261
with tempfile.TemporaryDirectory() as tmpdir:
290262
exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)}

0 commit comments

Comments (0)