Skip to content

Commit 3395c4c

Browse files
nevillelyh and claude committed
Avoid per-item numpy conversion in JaggedArrayStore write path
TreeStore.extend and extend_with_batch were converting each item to a numpy array individually before passing the list to JaggedArrayStore.extend, which then concatenated them. For a batch of 16K tokenized sequences this means 16K np.asarray calls plus one np.concatenate. Add PreparedBatch._from_sequences(), which pre-allocates a single flat array sized from the cumulative lengths and copies each sequence directly into its slice. PreparedBatch.from_batch, which JaggedArrayStore.extend and extend_async call, now detects items that are not numpy arrays (e.g. Python lists) and takes this fast path automatically for rank-1 items. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 493c9bb commit 3395c4c

3 files changed

Lines changed: 63 additions & 10 deletions

File tree

lib/levanter/src/levanter/store/jagged_array.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,30 @@ def num_rows(self):
7575
return len(self.offsets)
7676

7777
@staticmethod
def from_batch(items: Sequence, item_rank: Optional[int] = None) -> "PreparedBatch":
    """Build a PreparedBatch from a batch of items.

    Items exposing ``ndim`` (array-likes) are handed to ``_prepare_batch``
    unchanged. Items without it (e.g. Python lists) take the
    ``_from_sequences`` fast path when ``item_rank`` is 1 or unset;
    otherwise each item is converted with ``np.asarray`` first.

    Args:
        items: The batch; only the first item is inspected to pick a path.
        item_rank: Rank of each item. ``None`` is treated as 1 when
            deciding whether the fast path applies.

    Returns:
        A PreparedBatch holding the data, offsets and shapes of the batch.
    """
    # Use len() rather than truthiness: a NumPy array with more than one
    # element raises ValueError when evaluated as a bool, and callers may
    # hand over the batch itself as an array now that they no longer wrap
    # it in a list.
    if len(items) > 0 and not hasattr(items[0], "ndim"):
        if (item_rank or 1) == 1:
            return PreparedBatch._from_sequences(items)
        items = [np.asarray(x) for x in items]
    data, offsets, shapes = _prepare_batch(items, item_rank)
    return PreparedBatch(data, offsets, shapes)
8185

86+
@staticmethod
def _from_sequences(items: Sequence[Sequence]) -> "PreparedBatch":
    """Build a PreparedBatch from flat Python sequences using one shared buffer.

    A single flat array sized from the summed lengths is allocated and each
    sequence is copied into its slice, so no list of per-item arrays is
    kept and no final ``np.concatenate`` is needed.

    The buffer dtype is taken from the first non-empty item and widened
    (via ``np.result_type``) whenever a later item needs a wider dtype.
    Probing only the first element of the first item would silently
    truncate e.g. ``[[1, 2], [0.5]]`` to integers.

    Args:
        items: Sequences of scalars; each must support ``len()``.

    Returns:
        A PreparedBatch whose data is the flat buffer, whose offsets are the
        cumulative end positions of the items, and whose third field
        (``shapes`` in ``from_batch``) is ``None``.
    """
    lengths = np.array([len(item) for item in items], dtype=np.int64)
    offsets = np.cumsum(lengths)
    total = int(offsets[-1]) if len(offsets) else 0

    data = None
    pos = 0
    for item, length in zip(items, lengths):
        if length == 0:
            # Empty items contribute no data and carry no dtype information.
            continue
        chunk = np.asarray(item)
        if data is None:
            data = np.empty(total, dtype=chunk.dtype)
        elif chunk.dtype != data.dtype:
            widened = np.result_type(data.dtype, chunk.dtype)
            if widened != data.dtype:
                # Re-home only the filled prefix; the tail is still unset.
                grown = np.empty(total, dtype=widened)
                grown[:pos] = data[:pos]
                data = grown
        end = pos + int(length)
        data[pos:end] = chunk
        pos = end

    if data is None:
        # No elements at all: keep the default integer dtype used as the
        # fallback for element-less batches.
        data = np.empty(total, dtype=np.asarray([0]).dtype)
    return PreparedBatch(data, offsets, None)
101+
82102
@staticmethod
83103
def concat(batches: Sequence["PreparedBatch"]) -> "PreparedBatch":
84104
data = np.concatenate([batch.data for batch in batches])
@@ -282,12 +302,12 @@ def trim_to_size(self, size: int):
282302
self._cached_num_rows = size
283303
self._cached_data_size = new_max
284304

285-
async def extend_async(self, arrays: Sequence[np.ndarray] | PreparedBatch):
305+
async def extend_async(self, arrays: Sequence[np.ndarray] | Sequence[Sequence] | PreparedBatch):
286306
if isinstance(arrays, PreparedBatch):
287307
prepared = arrays
288308
else:
289309
prepared = PreparedBatch.from_batch(arrays, self.item_rank)
290-
data = prepared.data
310+
data = np.asarray(prepared.data, dtype=self.data.dtype.name)
291311
new_offsets = prepared.offsets
292312
shapes = prepared.shapes
293313

@@ -313,13 +333,13 @@ async def extend_async(self, arrays: Sequence[np.ndarray] | PreparedBatch):
313333
self._cached_num_rows = num_rows + num_added
314334
self._cached_data_size = current_data_size + len(data)
315335

316-
def extend(self, arrays: Sequence[np.ndarray] | PreparedBatch):
336+
def extend(self, arrays: Sequence[np.ndarray] | Sequence[Sequence] | PreparedBatch):
317337
if isinstance(arrays, PreparedBatch):
318338
prepared = arrays
319339
else:
320340
prepared = PreparedBatch.from_batch(arrays, self.item_rank)
321341

322-
data = prepared.data
342+
data = np.asarray(prepared.data, dtype=self.data.dtype.name)
323343
new_offsets = prepared.offsets
324344
shapes = prepared.shapes
325345

lib/levanter/src/levanter/store/tree_store.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def extend(self, batch: Sequence[T]):
7070
Append a batch of data to the store.
7171
"""
7272
jtu.tree_map(
73-
lambda writer, *xs: writer.extend([np.asarray(x) for x in xs]),
73+
lambda writer, *xs: writer.extend(xs),
7474
self.tree,
7575
*batch,
7676
is_leaf=heuristic_is_leaf,
@@ -84,7 +84,7 @@ def extend_with_batch(self, batch: T):
8484
For instance, HF's BatchEncoding is a dict of lists of numpy arrays.
8585
"""
8686
jtu.tree_map(
87-
lambda writer, xs: writer.extend(xs if isinstance(xs, PreparedBatch) else [np.asarray(x) for x in xs]),
87+
lambda writer, xs: writer.extend(xs if isinstance(xs, PreparedBatch) else xs),
8888
self.tree,
8989
batch,
9090
is_leaf=heuristic_is_leaf_batched,
@@ -98,9 +98,7 @@ async def extend_with_batch_async(self, batch: T):
9898
For instance, HF's BatchEncoding is a dict of lists of numpy arrays.
9999
"""
100100
futures = jtu.tree_map(
101-
lambda writer, xs: writer.extend_async(
102-
xs if isinstance(xs, PreparedBatch) else [np.asarray(x) for x in xs]
103-
),
101+
lambda writer, xs: writer.extend_async(xs if isinstance(xs, PreparedBatch) else xs),
104102
self.tree,
105103
batch,
106104
is_leaf=heuristic_is_leaf_batched,

lib/levanter/tests/test_jagged_array.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,5 +396,40 @@ async def test_get_batch_empty():
396396
assert batch == []
397397

398398

399+
def test_extend_with_python_lists():
    """Rows written as plain Python lists read back unchanged.

    Exercises the list fast path reached through PreparedBatch.from_batch,
    including a second extend to check that offsets keep accumulating.
    """
    first_batch = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
    second_batch = [[10, 11]]

    with tempfile.TemporaryDirectory() as tmpdir:
        store = JaggedArrayStore.open(tmpdir, item_rank=1, dtype=jnp.int32)

        store.extend(first_batch)
        assert len(store) == 3
        for row, expected in enumerate(first_batch):
            np.testing.assert_array_equal(store[row], np.array(expected))

        # A second write must land after the rows already stored.
        store.extend(second_batch)
        assert len(store) == 4
        np.testing.assert_array_equal(store[3], np.array([10, 11]))
417+
418+
419+
def test_from_batch_with_python_lists_matches_numpy():
    """from_batch yields the same values and offsets for lists as for arrays."""
    rows = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
    as_arrays = [np.array(row, dtype=np.int32) for row in rows]

    list_batch = PreparedBatch.from_batch(rows)
    array_batch = PreparedBatch.from_batch(as_arrays)

    # Only values are compared: the list path infers its own integer dtype,
    # which need not be the explicit int32 used for the arrays.
    for field in ("data", "offsets"):
        np.testing.assert_array_equal(getattr(list_batch, field), getattr(array_batch, field))
    assert list_batch.shapes is None
    assert array_batch.shapes is None
432+
433+
399434
if __name__ == "__main__":
400435
pytest.main()

0 commit comments

Comments
 (0)