Skip to content

Commit fd42cb2

Browse files
nevillelyh authored and claude committed
Avoid per-item numpy conversion in JaggedArrayStore write path
TreeStore.extend and extend_with_batch were converting each item to a numpy array individually before passing to JaggedArrayStore.extend, which then concatenated them. For a batch of 16K tokenized sequences this means 16K np.asarray calls + one np.concatenate. Add PreparedBatch.from_sequences() that pre-allocates a single flat array from the cumulative lengths and copies each sequence directly into the right slice. JaggedArrayStore.extend now detects Python sequences (lists) and uses this fast path automatically. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 493c9bb commit fd42cb2

3 files changed

Lines changed: 67 additions & 10 deletions

File tree

lib/levanter/src/levanter/store/jagged_array.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,35 @@ def num_rows(self):
7575
return len(self.offsets)
7676

7777
@staticmethod
78-
def from_batch(items: Sequence[np.ndarray], item_rank: Optional[int] = None) -> "PreparedBatch":
78+
def from_batch(
79+
items: Sequence, item_rank: Optional[int] = None, dtype: Optional[np.dtype] = None
80+
) -> "PreparedBatch":
81+
if items and not hasattr(items[0], "ndim"):
82+
if (item_rank or 1) == 1:
83+
return PreparedBatch._from_sequences(items, dtype)
84+
items = [np.asarray(x) for x in items]
7985
data, offsets, shapes = _prepare_batch(items, item_rank)
8086
return PreparedBatch(data, offsets, shapes)
8187

88+
@staticmethod
89+
def _from_sequences(items: Sequence[Sequence], dtype: Optional[np.dtype]) -> "PreparedBatch":
90+
"""Build from Python sequences without per-item numpy conversion.
91+
Pre-allocates a single flat array and copies each sequence into it."""
92+
lengths = np.array([len(item) for item in items], dtype=np.int64)
93+
offsets = np.cumsum(lengths)
94+
total = int(offsets[-1]) if len(offsets) else 0
95+
if dtype is None:
96+
# Infer from first non-empty item
97+
probe = np.asarray(items[0][:1]) if items and len(items[0]) > 0 else np.asarray([0])
98+
dtype = probe.dtype
99+
data = np.empty(total, dtype=dtype)
100+
pos = 0
101+
for item, length in zip(items, lengths):
102+
end = pos + int(length)
103+
data[pos:end] = item
104+
pos = end
105+
return PreparedBatch(data, offsets, None)
106+
82107
@staticmethod
83108
def concat(batches: Sequence["PreparedBatch"]) -> "PreparedBatch":
84109
data = np.concatenate([batch.data for batch in batches])
@@ -282,11 +307,11 @@ def trim_to_size(self, size: int):
282307
self._cached_num_rows = size
283308
self._cached_data_size = new_max
284309

285-
async def extend_async(self, arrays: Sequence[np.ndarray] | PreparedBatch):
310+
async def extend_async(self, arrays: Sequence[np.ndarray] | Sequence[Sequence] | PreparedBatch):
286311
if isinstance(arrays, PreparedBatch):
287312
prepared = arrays
288313
else:
289-
prepared = PreparedBatch.from_batch(arrays, self.item_rank)
314+
prepared = PreparedBatch.from_batch(arrays, self.item_rank, dtype=np.dtype(self.data.dtype.name))
290315
data = prepared.data
291316
new_offsets = prepared.offsets
292317
shapes = prepared.shapes
@@ -313,11 +338,11 @@ async def extend_async(self, arrays: Sequence[np.ndarray] | PreparedBatch):
313338
self._cached_num_rows = num_rows + num_added
314339
self._cached_data_size = current_data_size + len(data)
315340

316-
def extend(self, arrays: Sequence[np.ndarray] | PreparedBatch):
341+
def extend(self, arrays: Sequence[np.ndarray] | Sequence[Sequence] | PreparedBatch):
317342
if isinstance(arrays, PreparedBatch):
318343
prepared = arrays
319344
else:
320-
prepared = PreparedBatch.from_batch(arrays, self.item_rank)
345+
prepared = PreparedBatch.from_batch(arrays, self.item_rank, dtype=np.dtype(self.data.dtype.name))
321346

322347
data = prepared.data
323348
new_offsets = prepared.offsets

lib/levanter/src/levanter/store/tree_store.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def extend(self, batch: Sequence[T]):
7070
Append a batch of data to the store.
7171
"""
7272
jtu.tree_map(
73-
lambda writer, *xs: writer.extend([np.asarray(x) for x in xs]),
73+
lambda writer, *xs: writer.extend(xs),
7474
self.tree,
7575
*batch,
7676
is_leaf=heuristic_is_leaf,
@@ -84,7 +84,7 @@ def extend_with_batch(self, batch: T):
8484
For instance, HF's BatchEncoding is a dict of lists of numpy arrays.
8585
"""
8686
jtu.tree_map(
87-
lambda writer, xs: writer.extend(xs if isinstance(xs, PreparedBatch) else [np.asarray(x) for x in xs]),
87+
lambda writer, xs: writer.extend(xs if isinstance(xs, PreparedBatch) else xs),
8888
self.tree,
8989
batch,
9090
is_leaf=heuristic_is_leaf_batched,
@@ -98,9 +98,7 @@ async def extend_with_batch_async(self, batch: T):
9898
For instance, HF's BatchEncoding is a dict of lists of numpy arrays.
9999
"""
100100
futures = jtu.tree_map(
101-
lambda writer, xs: writer.extend_async(
102-
xs if isinstance(xs, PreparedBatch) else [np.asarray(x) for x in xs]
103-
),
101+
lambda writer, xs: writer.extend_async(xs if isinstance(xs, PreparedBatch) else xs),
104102
self.tree,
105103
batch,
106104
is_leaf=heuristic_is_leaf_batched,

lib/levanter/tests/test_jagged_array.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,5 +396,39 @@ async def test_get_batch_empty():
396396
assert batch == []
397397

398398

399+
def test_extend_with_python_lists():
    """JaggedArrayStore.extend should accept plain Python lists directly,
    routing them through the PreparedBatch._from_sequences fast path and
    storing exactly the same rows as numpy-array input would."""
    with tempfile.TemporaryDirectory() as tmpdir:
        store = JaggedArrayStore.open(tmpdir, item_rank=1, dtype=jnp.int32)

        rows = ([1, 2, 3], [4, 5], [6, 7, 8, 9])
        store.extend([list(r) for r in rows])

        assert len(store) == 3
        for idx, expected in enumerate(rows):
            np.testing.assert_array_equal(store[idx], np.array(expected))

        # A second extend must continue from the existing offsets.
        store.extend([[10, 11]])
        assert len(store) == 4
        np.testing.assert_array_equal(store[3], np.array([10, 11]))
417+
418+
419+
def test_from_batch_with_python_lists_matches_numpy():
    """PreparedBatch.from_batch must yield identical data and offsets whether
    fed Python lists (fast path) or pre-converted numpy arrays (generic path)."""
    raw = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]

    via_lists = PreparedBatch.from_batch(raw, dtype=np.int32)
    via_arrays = PreparedBatch.from_batch([np.array(r, dtype=np.int32) for r in raw])

    np.testing.assert_array_equal(via_lists.data, via_arrays.data)
    np.testing.assert_array_equal(via_lists.offsets, via_arrays.offsets)
    assert via_lists.shapes is None
    assert via_arrays.shapes is None
431+
432+
399433
if __name__ == "__main__":
400434
pytest.main()

0 commit comments

Comments
 (0)