diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py
index 26ee044..fc852a7 100644
--- a/arrayloaders/io/zarr_loader.py
+++ b/arrayloaders/io/zarr_loader.py
@@ -32,6 +32,8 @@
 def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]:
+    if size == 1:
+        return list(a.reshape(-1, 1))  # one shape-(1,) batch per element, preserving the list return type
     return np.split(a, np.arange(size, len(a), size))
@@ -301,10 +303,14 @@ def iter(
         )
         # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0
         # we must keep track of the leftover data.
-        in_memory_data = None
+        chunks: list[InputInMemoryArray] = []
         in_memory_labels = None
         in_memory_indices = None
-        mod = self.sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np
+        vstack = (
+            self.sp_module.vstack
+            if issubclass(self.dataset_type, ad.abc.CSRDataset)
+            else np.vstack
+        )
         for chunk_indices in _batched(
             self._get_chunks(chunk_size, worker_handle, shuffle), preload_nchunks
         ):
@@ -317,18 +323,9 @@ def iter(
             ]
             dataset_index_to_slices = self._slices_to_slices_with_array_index(slices)
             # Fetch the data over slices
-            chunks: list[InputInMemoryArray] = zsync.sync(
-                index_datasets(dataset_index_to_slices, fetch_data)
-            )
-            if any(isinstance(c, CSRContainer) for c in chunks):
-                chunks_converted: list[OutputInMemoryArray] = [
-                    self.sp_module.csr_matrix(
-                        tuple(self.np_module.asarray(e) for e in c.elems), shape=c.shape
-                    )
-                    for c in chunks
-                ]
-            else:
-                chunks_converted = [self.np_module.asarray(c) for c in chunks]
+            chunks += zsync.sync(index_datasets(dataset_index_to_slices, fetch_data))
+            chunks_converted = self._to_output_array(chunks)
+
             # Accumulate labels
             labels: None | list[np.ndarray] = None
             if self.labels is not None:
@@ -364,11 +361,6 @@
                 for index in dataset_indices
             ]
             # Do batch returns, handling leftover data as necessary
-            in_memory_data = (
-                mod.vstack(chunks_converted)
-                if in_memory_data is None
-                else mod.vstack([in_memory_data, *chunks_converted])
-            )
             if self.labels is not None:
                 in_memory_labels = (
                     np.concatenate(labels)
@@ -384,41 +376,96 @@
             # Create random indices into in_memory_data and then index into it
             # If there is "leftover" at the end (see the modulo op),
             # save it for the next iteration.
-            batch_indices = np.arange(in_memory_data.shape[0])
-            if shuffle:
-                np.random.default_rng().shuffle(batch_indices)
-            splits = split_given_size(batch_indices, self._batch_size)
-            for i, s in enumerate(splits):
-                if s.shape[0] == self._batch_size:
-                    res = [
-                        in_memory_data[s],
-                        in_memory_labels[s] if self.labels is not None else None,
-                    ]
-                    if self._return_index:
-                        res += [in_memory_indices[s]]
-                    yield tuple(res)
-                if i == (
-                    len(splits) - 1
-                ):  # end of iteration, leftover data needs be kept
-                    if (s.shape[0] % self._batch_size) != 0:
-                        in_memory_data = in_memory_data[s]
-                        if in_memory_labels is not None:
-                            in_memory_labels = in_memory_labels[s]
-                        if in_memory_indices is not None:
-                            in_memory_indices = in_memory_indices[s]
-                    else:
-                        in_memory_data = None
-                        in_memory_labels = None
-                        in_memory_indices = None
-        if in_memory_data is not None:  # handle any leftover data
+            if self._batch_size != (num_obs := sum(c.shape[0] for c in chunks)):
+                batch_indices = np.arange(num_obs)
+                if shuffle:
+                    np.random.default_rng().shuffle(batch_indices)
+                splits = split_given_size(batch_indices, self._batch_size)
+                for i, s in enumerate(splits):
+                    s, chunks_reindexed = self._reindex_against_integer_indices(
+                        s, chunks_converted
+                    )
+                    if s.shape[0] == self._batch_size:
+                        res = [
+                            vstack(chunks_reindexed)
+                            if len(chunks_reindexed) > 1
+                            else chunks_reindexed[0],
+                            in_memory_labels[s] if self.labels is not None else None,
+                        ]
+                        if self._return_index:
+                            res += [in_memory_indices[s]]
+                        yield tuple(res)
+                    if i == (
+                        len(splits) - 1
+                    ):  # end of iteration, leftover data needs to be kept
+                        if (s.shape[0] % self._batch_size) != 0:
+                            chunks = chunks_reindexed
+                            if in_memory_labels is not None:
+                                in_memory_labels = in_memory_labels[s]
+                            if in_memory_indices is not None:
+                                in_memory_indices = in_memory_indices[s]
+                        else:
+                            chunks = []
+                            in_memory_labels = None
+                            in_memory_indices = None
+            elif len(chunks_converted) > 0:  # batch size matches the in-memory data exactly
+                res = [
+                    vstack(chunks_converted),
+                    in_memory_labels if self.labels is not None else None,
+                ]
+                if self._return_index:
+                    res += [in_memory_indices]
+                yield tuple(res)
+                chunks = []
+                in_memory_labels = None
+                in_memory_indices = None
+        if len(chunks) > 0:  # handle leftover data
             res = [
-                in_memory_data,
-                in_memory_labels if self.labels is not None else None,
+                vstack(self._to_output_array(chunks)),
+                np.asarray(in_memory_labels) if self.labels is not None else None,
             ]
             if self._return_index:
-                res += [in_memory_indices]
+                res += [np.asarray(in_memory_indices)]
             yield tuple(res)
 
+    def _to_output_array(
+        self, chunks: list[InputInMemoryArray | OutputInMemoryArray]
+    ) -> list[OutputInMemoryArray]:
+        if any(isinstance(c, CSRContainer) for c in chunks):
+            return [
+                self.sp_module.csr_matrix(
+                    tuple(self.np_module.asarray(e) for e in c.elems), shape=c.shape
+                )
+                if isinstance(c, CSRContainer)
+                else c
+                for c in chunks
+            ]
+        elif any(isinstance(c, np.ndarray) for c in chunks):
+            return [self.np_module.asarray(c) for c in chunks]
+        return chunks
+
+    def _reindex_against_integer_indices(
+        self, indices: np.ndarray, chunks: list[OutputInMemoryArray]
+    ) -> tuple[np.ndarray, list[OutputInMemoryArray]]:
+        upper_bounds = np.cumsum(np.array([c.shape[0] for c in chunks]))
+        lower_bounds = np.concatenate([np.array([0]), upper_bounds[:-1]])
+        reindexed, chunks_reindexed = list(
+            zip(
+                *(
+                    (reindexed, c[self.np_module.asarray(reindexed - lower)])
+                    for c, upper, lower in zip(
+                        chunks, upper_bounds, lower_bounds, strict=False
+                    )
+                    if (
+                        reindexed := indices[(indices < upper) & (indices >= lower)]
+                    ).shape[0]
+                    > 0
+                ),
+                strict=False,
+            )
+        )
+        return np.concatenate(reindexed), list(chunks_reindexed)
+
 AnnDataManager.add_datasets.__doc__ = add_dataset_docstring
 AnnDataManager.add_dataset.__doc__ = add_dataset_docstring
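
The size == 1 fast path in split_given_size must keep returning a list of shape-(1,) index batches, because iter() checks s.shape[0] on every split (a bare ndarray would yield scalar indices with an empty shape). A minimal standalone check of the helper's behavior, mirroring the patched version above with plain NumPy:

import numpy as np

def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]:
    # A list of at-most-`size` index batches, in order.
    if size == 1:
        return list(a.reshape(-1, 1))  # one shape-(1,) batch per element
    return np.split(a, np.arange(size, len(a), size))

print([b.tolist() for b in split_given_size(np.arange(7), 3)])
# -> [[0, 1, 2], [3, 4, 5], [6]]
print([b.tolist() for b in split_given_size(np.arange(3), 1)])
# -> [[0], [1], [2]]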
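The leftover handling now keeps the preloaded data as a list of per-chunk arrays and defers vstack to yield time; _reindex_against_integer_indices maps shuffled global row indices onto that chunk list. Below is a standalone sketch of the mapping, where reindex_against_chunks is an illustrative stand-in rather than the module's API. Note that the indices come back regrouped chunk by chunk, not in input order, which is why the permuted index array is returned alongside the chunk slices so labels and indices stay aligned with the stacked rows:

import numpy as np

def reindex_against_chunks(
    indices: np.ndarray, chunks: list[np.ndarray]
) -> tuple[np.ndarray, list[np.ndarray]]:
    # Global row offsets covered by each chunk: [lower, upper).
    uppers = np.cumsum([c.shape[0] for c in chunks])
    lowers = np.concatenate([[0], uppers[:-1]])
    kept_indices, kept_chunks = [], []
    for chunk, lower, upper in zip(chunks, lowers, uppers):
        mask = (indices >= lower) & (indices < upper)
        if mask.any():
            kept_indices.append(indices[mask])
            # Shift global indices to chunk-local offsets before slicing.
            kept_chunks.append(chunk[indices[mask] - lower])
    return np.concatenate(kept_indices), kept_chunks

chunks = [np.arange(6).reshape(3, 2), 10 + np.arange(4).reshape(2, 2)]
idx, pieces = reindex_against_chunks(np.array([4, 0, 3]), chunks)
print(idx)                # [0 4 3] -- regrouped by chunk, not input order
print(np.vstack(pieces))  # rows 0, 4, 3 of the virtual concatenation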
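On the sparse path, _to_output_array rebuilds CSR matrices from the raw (data, indices, indptr) buffers carried by each CSRContainer, and sp_module.vstack stacks them only when a batch is yielded. A SciPy-only sketch of that construction, with made-up buffers standing in for the fetched chunks:

import numpy as np
import scipy.sparse as sp

# Two CSR chunks from raw (data, indices, indptr) buffers, shapes (2, 3) and (1, 3).
a = sp.csr_matrix(
    (np.array([1.0, 2.0]), np.array([0, 2]), np.array([0, 1, 2])), shape=(2, 3)
)
b = sp.csr_matrix((np.array([3.0]), np.array([1]), np.array([0, 1])), shape=(1, 3))
print(sp.vstack([a, b]).toarray())
# [[1. 0. 0.]
#  [0. 0. 2.]
#  [0. 3. 0.]]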