4444from .registry import _REGISTRY , IOSpec , read_elem , read_elem_partial
4545
4646if TYPE_CHECKING :
47- from collections .abc import Callable
47+ from collections .abc import Callable , Iterator
4848 from os import PathLike
4949 from typing import Any , Literal
5050
@@ -374,13 +374,12 @@ def write_list(
374374# It's in the `AnnData.concatenate` docstring, but should we keep it?
375375@_REGISTRY .register_write (H5Group , views .ArrayView , IOSpec ("array" , "0.2.0" ))
376376@_REGISTRY .register_write (H5Group , np .ndarray , IOSpec ("array" , "0.2.0" ))
377- @_REGISTRY .register_write (H5Group , h5py .Dataset , IOSpec ("array" , "0.2.0" ))
378377@_REGISTRY .register_write (H5Group , np .ma .MaskedArray , IOSpec ("array" , "0.2.0" ))
379378@_REGISTRY .register_write (ZarrGroup , views .ArrayView , IOSpec ("array" , "0.2.0" ))
380379@_REGISTRY .register_write (ZarrGroup , np .ndarray , IOSpec ("array" , "0.2.0" ))
381- @_REGISTRY .register_write (ZarrGroup , h5py .Dataset , IOSpec ("array" , "0.2.0" ))
382380@_REGISTRY .register_write (ZarrGroup , np .ma .MaskedArray , IOSpec ("array" , "0.2.0" ))
383381@_REGISTRY .register_write (ZarrGroup , ZarrArray , IOSpec ("array" , "0.2.0" ))
382+ @_REGISTRY .register_write (ZarrGroup , H5Array , IOSpec ("array" , "0.2.0" ))
384383@zero_dim_array_as_scalar
385384def write_basic (
386385 f : GroupStorageType ,
@@ -394,6 +393,51 @@ def write_basic(
394393 f .create_dataset (k , data = elem , ** dataset_kwargs )
395394
396395
396+ def _iter_chunks_for_copy (
397+ elem : ArrayStorageType , dest : ArrayStorageType
398+ ) -> Iterator [slice | tuple [list [slice ]]]:
399+ """
400+ Returns an iterator of tuples of slices for copying chunks from `elem` to `dest`.
401+
402+ * If `dest` has chunks, it will return the chunks of `dest`.
403+ * If `dest` is not chunked, we write it in ~100MB chunks or 1000 rows, whichever is larger.
404+ """
405+ if dest .chunks and hasattr (dest , "iter_chunks" ):
406+ return dest .iter_chunks ()
407+ else :
408+ shape = elem .shape
409+ # Number of rows that works out to
410+ n_rows = max (
411+ ad .settings .min_rows_for_chunked_h5_copy ,
412+ elem .chunks [0 ] if elem .chunks is not None else 1 ,
413+ )
414+ return (slice (i , min (i + n_rows , shape [0 ])) for i in range (0 , shape [0 ], n_rows ))
415+
416+
@_REGISTRY.register_write(H5Group, H5Array, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, ZarrArray, IOSpec("array", "0.2.0"))
def write_chunked_dense_array_to_group(
    f: GroupStorageType,
    k: str,
    elem: ArrayStorageType,
    *,
    _writer: Writer,
    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
    """Copy an on-disk dense array into a new dataset `k` of `f`, piece by piece.

    `h5py.Group.create_dataset(..., data: h5py.Dataset)` will load all of `data`
    into memory before writing. To avoid that, the destination dataset is
    created empty with the shape of `elem` and then filled one selection at a
    time. This is not needed for zarr, since zarr handles this automatically.

    The dataset is created with the `dtype` given in `dataset_kwargs` when
    present, otherwise with `elem.dtype`; all other `dataset_kwargs` are
    forwarded to `create_dataset` unchanged.
    """
    create_kwargs = dict(dataset_kwargs)
    # An explicitly requested dtype wins; otherwise keep the source dtype.
    create_kwargs.setdefault("dtype", elem.dtype)
    target = f.create_dataset(k, shape=elem.shape, **create_kwargs)

    for selection in _iter_chunks_for_copy(elem, target):
        target[selection] = elem[selection]
439+
440+
397441_REGISTRY .register_write (H5Group , CupyArray , IOSpec ("array" , "0.2.0" ))(
398442 _to_cpu_mem_wrapper (write_basic )
399443)
@@ -604,9 +648,15 @@ def write_sparse_compressed(
604648 if isinstance (f , H5Group ) and "maxshape" not in dataset_kwargs :
605649 dataset_kwargs = dict (maxshape = (None ,), ** dataset_kwargs )
606650
607- g .create_dataset ("data" , data = value .data , ** dataset_kwargs )
608- g .create_dataset ("indices" , data = value .indices , ** dataset_kwargs )
609- g .create_dataset ("indptr" , data = value .indptr , dtype = indptr_dtype , ** dataset_kwargs )
651+ for attr_name in ["data" , "indices" , "indptr" ]:
652+ attr = getattr (value , attr_name )
653+ dtype = indptr_dtype if attr_name == "indptr" else attr .dtype
654+ _writer .write_elem (
655+ g ,
656+ attr_name ,
657+ attr ,
658+ dataset_kwargs = {"dtype" : dtype , ** dataset_kwargs },
659+ )
610660
611661
612662write_csr = partial (write_sparse_compressed , fmt = "csr" )
0 commit comments