145 changes: 96 additions & 49 deletions arrayloaders/io/zarr_loader.py
@@ -32,6 +32,8 @@


def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]:
+    if size == 1:
+        # Each index becomes its own length-1 batch; return shape-(1,) arrays
+        # rather than the bare 1-D array so callers can rely on `s.shape[0]`.
+        return list(a.reshape(-1, 1))
    return np.split(a, np.arange(size, len(a), size))


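Note: a quick, self-contained sketch of what split_given_size produces, with hypothetical sizes. The trailing short array is exactly the leftover that iter() has to carry over whenever (chunk_size * preload_nchunks) mod batch_size != 0:

import numpy as np

def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]:
    if size == 1:
        return list(a.reshape(-1, 1))
    return np.split(a, np.arange(size, len(a), size))

# 10 preloaded row indices, batch_size 4 -> two full batches plus a leftover of 2.
splits = split_given_size(np.arange(10), 4)
print([s.tolist() for s in splits])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]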
@@ -301,10 +303,14 @@ def iter(
        )
        # To handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0,
        # we must keep track of the leftover data.
-        in_memory_data = None
+        chunks: list[InputInMemoryArray] = []
        in_memory_labels = None
        in_memory_indices = None
-        mod = self.sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np
+        vstack = (
+            self.sp_module.vstack
+            if issubclass(self.dataset_type, ad.abc.CSRDataset)
+            else np.vstack
+        )
        for chunk_indices in _batched(
            self._get_chunks(chunk_size, worker_handle, shuffle), preload_nchunks
        ):
@@ -317,18 +323,9 @@ def iter(
            ]
            dataset_index_to_slices = self._slices_to_slices_with_array_index(slices)
            # Fetch the data over slices
-            chunks: list[InputInMemoryArray] = zsync.sync(
-                index_datasets(dataset_index_to_slices, fetch_data)
-            )
-            if any(isinstance(c, CSRContainer) for c in chunks):
-                chunks_converted: list[OutputInMemoryArray] = [
-                    self.sp_module.csr_matrix(
-                        tuple(self.np_module.asarray(e) for e in c.elems), shape=c.shape
-                    )
-                    for c in chunks
-                ]
-            else:
-                chunks_converted = [self.np_module.asarray(c) for c in chunks]
+            chunks += zsync.sync(index_datasets(dataset_index_to_slices, fetch_data))
+            chunks_converted = self._to_output_array(chunks)

            # Accumulate labels
            labels: None | list[np.ndarray] = None
            if self.labels is not None:
@@ -364,11 +361,6 @@ def iter(
                    for index in dataset_indices
                ]
            # Do batch returns, handling leftover data as necessary
-            in_memory_data = (
-                mod.vstack(chunks_converted)
-                if in_memory_data is None
-                else mod.vstack([in_memory_data, *chunks_converted])
-            )
            if self.labels is not None:
                in_memory_labels = (
                    np.concatenate(labels)
@@ -384,41 +376,96 @@ def iter(
            # Create random indices into the in-memory data and then index into it.
            # If there is "leftover" at the end (see the modulo op),
            # save it for the next iteration.
-            batch_indices = np.arange(in_memory_data.shape[0])
-            if shuffle:
-                np.random.default_rng().shuffle(batch_indices)
-            splits = split_given_size(batch_indices, self._batch_size)
-            for i, s in enumerate(splits):
-                if s.shape[0] == self._batch_size:
-                    res = [
-                        in_memory_data[s],
-                        in_memory_labels[s] if self.labels is not None else None,
-                    ]
-                    if self._return_index:
-                        res += [in_memory_indices[s]]
-                    yield tuple(res)
-                if i == (
-                    len(splits) - 1
-                ):  # end of iteration, leftover data needs to be kept
-                    if (s.shape[0] % self._batch_size) != 0:
-                        in_memory_data = in_memory_data[s]
-                        if in_memory_labels is not None:
-                            in_memory_labels = in_memory_labels[s]
-                        if in_memory_indices is not None:
-                            in_memory_indices = in_memory_indices[s]
-                    else:
-                        in_memory_data = None
-                        in_memory_labels = None
-                        in_memory_indices = None
-        if in_memory_data is not None:  # handle any leftover data
+            if self._batch_size != (num_obs := sum(c.shape[0] for c in chunks)):
+                batch_indices = np.arange(num_obs)
+                if shuffle:
+                    np.random.default_rng().shuffle(batch_indices)
+                splits = split_given_size(batch_indices, self._batch_size)
+                for i, s in enumerate(splits):
+                    s, chunks_reindexed = self._reindex_against_integer_indices(
+                        s, chunks_converted
+                    )
+                    if s.shape[0] == self._batch_size:
+                        res = [
+                            vstack(chunks_reindexed)
+                            if len(chunks_reindexed) > 1
+                            else chunks_reindexed[0],
+                            in_memory_labels[s] if self.labels is not None else None,
+                        ]
+                        if self._return_index:
+                            res += [in_memory_indices[s]]
+                        yield tuple(res)
+                    if i == (
+                        len(splits) - 1
+                    ):  # end of iteration, leftover data needs to be kept
+                        if (s.shape[0] % self._batch_size) != 0:
+                            chunks = chunks_reindexed
+                            if in_memory_labels is not None:
+                                in_memory_labels = in_memory_labels[s]
+                            if in_memory_indices is not None:
+                                in_memory_indices = in_memory_indices[s]
+                        else:
+                            chunks = []
+                            in_memory_labels = None
+                            in_memory_indices = None
+            elif len(chunks_converted) > 0:  # handle batch size == in-memory size
+                res = [
+                    vstack(chunks_converted),
+                    in_memory_labels if self.labels is not None else None,
+                ]
+                if self._return_index:
+                    res += [in_memory_indices]
+                yield tuple(res)
+                chunks = []
+                in_memory_labels = None
+                in_memory_indices = None
+        if len(chunks) > 0:  # handle leftover data
            res = [
-                in_memory_data,
-                in_memory_labels if self.labels is not None else None,
+                vstack(self._to_output_array(chunks)),
+                np.asarray(in_memory_labels) if self.labels is not None else None,
            ]
            if self._return_index:
-                res += [in_memory_indices]
+                res += [np.asarray(in_memory_indices)]
            yield tuple(res)

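Note: to sanity-check the new control flow, here is a minimal stand-alone sketch of the accumulate/split/carry pattern iter() now follows. It is dense-NumPy only, uses a hypothetical iter_batches helper, and (unlike the PR, which reindexes per chunk via _reindex_against_integer_indices) vstacks eagerly for brevity; labels, indices, and the async fetch are omitted:

import numpy as np

def iter_batches(chunk_stream, batch_size, rng=None):
    """Yield vstacked batches of exactly batch_size rows, carrying leftovers."""
    chunks: list[np.ndarray] = []
    for chunk in chunk_stream:
        chunks.append(chunk)
        num_obs = sum(c.shape[0] for c in chunks)
        if num_obs < batch_size:
            continue  # not enough rows yet; keep preloading
        indices = np.arange(num_obs)
        if rng is not None:
            rng.shuffle(indices)
        data = np.vstack(chunks)
        full, leftover = divmod(num_obs, batch_size)
        for i in range(full):
            yield data[indices[i * batch_size : (i + 1) * batch_size]]
        # Carry any remainder into the next round of preloaded chunks.
        chunks = [data[indices[full * batch_size :]]] if leftover else []
    if chunks:  # final partial batch
        yield np.vstack(chunks)

rng = np.random.default_rng(0)
stream = (np.ones((n, 2)) * k for k, n in enumerate([3, 3, 3]))  # 9 rows total
sizes = [b.shape[0] for b in iter_batches(stream, batch_size=4, rng=rng)]
print(sizes)  # [4, 4, 1]

The key invariant matches the PR: every yielded batch has exactly batch_size rows except possibly the final one.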
+    def _to_output_array(
+        self, chunks: list[InputInMemoryArray | OutputInMemoryArray]
+    ) -> list[OutputInMemoryArray]:
+        if any(isinstance(c, CSRContainer) for c in chunks):
+            return [
+                self.sp_module.csr_matrix(
+                    tuple(self.np_module.asarray(e) for e in c.elems), shape=c.shape
+                )
+                if isinstance(c, CSRContainer)
+                else c
+                for c in chunks
+            ]
+        elif any(isinstance(c, np.ndarray) for c in chunks):
+            return [self.np_module.asarray(c) for c in chunks]
+        return chunks

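Note: for readers who have not seen CSRContainer, a rough sketch of the conversion _to_output_array performs, using plain scipy/numpy and a hypothetical stand-in dataclass (the real CSRContainer ships elsewhere in arrayloaders). The same CSR-vs-dense distinction is why iter() picks self.sp_module.vstack over np.vstack:

from dataclasses import dataclass

import numpy as np
import scipy.sparse as sp

@dataclass
class FakeCSRContainer:  # hypothetical stand-in for arrayloaders' CSRContainer
    elems: tuple  # (data, indices, indptr)
    shape: tuple

c = FakeCSRContainer(
    elems=(np.array([1.0, 2.0]), np.array([0, 1]), np.array([0, 1, 2])),
    shape=(2, 2),
)
# Same constructor form as in _to_output_array: csr_matrix((data, indices, indptr), shape=...)
mat = sp.csr_matrix(tuple(np.asarray(e) for e in c.elems), shape=c.shape)
print(mat.toarray())  # [[1. 0.], [0. 2.]]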
+    def _reindex_against_integer_indices(
+        self, indices: np.ndarray, chunks: list[OutputInMemoryArray]
+    ) -> tuple[np.ndarray, list[OutputInMemoryArray]]:
+        upper_bounds = np.cumsum(np.array([c.shape[0] for c in chunks]))
+        lower_bounds = np.concatenate([np.array([0]), upper_bounds[:-1]])
+        reindexed, chunks_reindexed = list(
+            zip(
+                *(
+                    (reindexed, c[self.np_module.asarray(reindexed - lower)])
+                    for c, upper, lower in zip(
+                        chunks, upper_bounds, lower_bounds, strict=False
+                    )
+                    if (
+                        reindexed := indices[(indices < upper) & (indices >= lower)]
+                    ).shape[0]
+                    > 0
+                ),
+                strict=False,
+            )
+        )
+        return np.concatenate(reindexed), list(chunks_reindexed)

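Note: a worked example of the boundary arithmetic in _reindex_against_integer_indices, on dense toy chunks (hypothetical values). Global row indices are bucketed into chunks via cumulative-sum bounds and shifted into local coordinates:

import numpy as np

chunks = [np.arange(3), np.arange(3, 7)]  # chunk sizes 3 and 4 -> 7 global rows
upper = np.cumsum([c.shape[0] for c in chunks])  # [3, 7]
lower = np.concatenate([[0], upper[:-1]])  # [0, 3]

indices = np.array([5, 1, 6, 0])  # shuffled global row indices
for c, lo, hi in zip(chunks, lower, upper):
    hits = indices[(indices >= lo) & (indices < hi)]
    print(hits.tolist(), c[hits - lo].tolist())
# [1, 0] [1, 0]
# [5, 6] [5, 6]

Because hits are grouped per chunk, the concatenated indices come back in chunk order rather than the shuffled order; the method returns those reordered indices alongside the rows so that labels and data stay aligned.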

AnnDataManager.add_datasets.__doc__ = add_dataset_docstring
AnnDataManager.add_dataset.__doc__ = add_dataset_docstring