2727from .._core .sparse_dataset import BaseCompressedSparseDataset , sparse_dataset
2828from .._io .specs import read_elem , write_elem
2929from ..compat import H5Array , H5Group , ZarrArray , ZarrGroup
30- from . import read_dispatched
30+ from . import read_dispatched , read_elem_lazy
3131
3232if TYPE_CHECKING :
3333 from collections .abc import Callable , Collection , Iterable , Sequence
@@ -173,7 +173,7 @@ def write_concat_dense( # noqa: PLR0917
173173 output_path : ZarrGroup | H5Group ,
174174 axis : Literal [0 , 1 ] = 0 ,
175175 reindexers : Reindexer | None = None ,
176- fill_value = None ,
176+ fill_value : Any = None ,
177177):
178178 """
179179 Writes the concatenation of given dense arrays to disk using dask.
@@ -206,7 +206,7 @@ def write_concat_sparse( # noqa: PLR0917
206206 max_loaded_elems : int ,
207207 axis : Literal [0 , 1 ] = 0 ,
208208 reindexers : Reindexer | None = None ,
209- fill_value = None ,
209+ fill_value : Any = None ,
210210):
211211 """
212212 Writes and concatenates sparse datasets into a single output dataset.
@@ -246,20 +246,20 @@ def write_concat_sparse( # noqa: PLR0917
246246
247247
248248def _write_concat_mappings ( # noqa: PLR0913, PLR0917
249- mappings ,
249+ mappings : Collection [ dict ] ,
250250 output_group : ZarrGroup | H5Group ,
251- keys ,
252- path ,
253- max_loaded_elems ,
254- axis = 0 ,
255- index = None ,
256- reindexers = None ,
257- fill_value = None ,
251+ keys : Collection [ str ] ,
252+ output_path : str | Path ,
253+ max_loaded_elems : int ,
254+ axis : Literal [ 0 , 1 ] = 0 ,
255+ index : pd . Index = None ,
256+ reindexers : list [ Reindexer ] | None = None ,
257+ fill_value : Any = None ,
258258):
259259 """
260260 Write a list of mappings to a zarr/h5 group.
261261 """
262- mapping_group = output_group .create_group (path )
262+ mapping_group = output_group .create_group (output_path )
263263 mapping_group .attrs .update ({
264264 "encoding-type" : "dict" ,
265265 "encoding-version" : "0.1.0" ,
@@ -280,13 +280,13 @@ def _write_concat_mappings( # noqa: PLR0913, PLR0917
280280
281281def _write_concat_arrays ( # noqa: PLR0913, PLR0917
282282 arrays : Sequence [ZarrArray | H5Array | BaseCompressedSparseDataset ],
283- output_group ,
284- output_path ,
285- max_loaded_elems ,
286- axis = 0 ,
287- reindexers = None ,
288- fill_value = None ,
289- join = "inner" ,
283+ output_group : ZarrGroup | H5Group ,
284+ output_path : str | Path ,
285+ max_loaded_elems : int ,
286+ axis : Literal [ 0 , 1 ] = 0 ,
287+ reindexers : list [ Reindexer ] | None = None ,
288+ fill_value : Any = None ,
289+ join : Literal [ "inner" , "outer" ] = "inner" ,
290290):
291291 init_elem = arrays [0 ]
292292 init_type = type (init_elem )
@@ -324,14 +324,14 @@ def _write_concat_arrays( # noqa: PLR0913, PLR0917
324324
325325def _write_concat_sequence ( # noqa: PLR0913, PLR0917
326326 arrays : Sequence [pd .DataFrame | BaseCompressedSparseDataset | H5Array | ZarrArray ],
327- output_group ,
328- output_path ,
329- max_loaded_elems ,
330- axis = 0 ,
331- index = None ,
332- reindexers = None ,
333- fill_value = None ,
334- join = "inner" ,
327+ output_group : ZarrGroup | H5Group ,
328+ output_path : str | Path ,
329+ max_loaded_elems : int ,
330+ axis : Literal [ 0 , 1 ] = 0 ,
331+ index : pd . Index = None ,
332+ reindexers : list [ Reindexer ] | None = None ,
333+ fill_value : Any = None ,
334+ join : Literal [ "inner" , "outer" ] = "inner" ,
335335):
336336 """
337337 array, dataframe, csc_matrix, csc_matrix
@@ -376,17 +376,27 @@ def _write_concat_sequence( # noqa: PLR0913, PLR0917
376376 raise NotImplementedError (msg )
377377
378378
379- def _write_alt_mapping (groups , output_group , alt_axis_name , alt_indices , merge ):
380- alt_mapping = merge ([read_as_backed (g [alt_axis_name ]) for g in groups ])
381- # If its empty, we need to write an empty dataframe with the correct index
382- if not alt_mapping :
383- alt_df = pd .DataFrame (index = alt_indices )
384- write_elem (output_group , alt_axis_name , alt_df )
385- else :
386- write_elem (output_group , alt_axis_name , alt_mapping )
def _write_alt_mapping(
    groups: Collection[H5Group | ZarrGroup],
    output_group: ZarrGroup | H5Group,
    alt_axis_name: Literal["obs", "var"],
    merge: Callable,
    reindexers: list[Reindexer],
):
    """
    Merge the ``{alt_axis_name}m`` mappings of all inputs and write the result.

    Parameters
    ----------
    groups
        Input groups; each one is indexed with ``f"{alt_axis_name}m"``.
    output_group
        Group the merged mapping is written into.
    alt_axis_name
        Name of the axis that is *not* being concatenated.
    merge
        Callable that receives one ``dict`` per input group and returns the
        mapping to write.
    reindexers
        One reindexer per entry of ``groups``, in the same order.

    Raises
    ------
    ValueError
        If ``reindexers`` and ``groups`` differ in length (``zip(strict=True)``).
    """
    mapping_key = f"{alt_axis_name}m"
    # Every element is read with ``read_elem`` and reindexed along axis 0
    # before merging.
    # NOTE(review): presumably this aligns each input to the shared alt index
    # so the merge strategy compares like with like; confirm against Reindexer.
    per_group = [
        {
            name: reindex(read_elem(elem), axis=0)
            for name, elem in dict(group[mapping_key]).items()
        }
        for reindex, group in zip(reindexers, groups, strict=True)
    ]
    write_elem(output_group, mapping_key, merge(per_group))
387391
388392
389- def _write_alt_annot (groups , output_group , alt_axis_name , alt_indices , merge ):
393+ def _write_alt_annot (
394+ groups : Collection [H5Group , ZarrGroup ],
395+ output_group : ZarrGroup | H5Group ,
396+ alt_axis_name : Literal ["obs" , "var" ],
397+ alt_indices : pd .Index ,
398+ merge : Callable ,
399+ ):
390400 # Annotation for other axis
391401 alt_annot = merge_dataframes (
392402 [read_elem (g [alt_axis_name ]) for g in groups ], alt_indices , merge
@@ -395,7 +405,13 @@ def _write_alt_annot(groups, output_group, alt_axis_name, alt_indices, merge):
395405
396406
397407def _write_axis_annot ( # noqa: PLR0917
398- groups , output_group , axis_name , concat_indices , label , label_col , join
408+ groups : Collection [H5Group , ZarrGroup ],
409+ output_group : ZarrGroup | H5Group ,
410+ axis_name : Literal ["obs" , "var" ],
411+ concat_indices : pd .Index ,
412+ label : str ,
413+ label_col : str ,
414+ join : Literal ["inner" , "outer" ],
399415):
400416 concat_annot = pd .concat (
401417 unify_dtypes (read_elem (g [axis_name ]) for g in groups ),
@@ -408,6 +424,23 @@ def _write_axis_annot( # noqa: PLR0917
408424 write_elem (output_group , axis_name , concat_annot )
409425
410426
def _write_alt_pairwise(
    groups: Collection[H5Group | ZarrGroup],
    output_group: ZarrGroup | H5Group,
    alt_axis_name: Literal["obs", "var"],
    merge: Callable,
    reindexers: list[Reindexer],
):
    """
    Merge the ``{alt_axis_name}p`` mappings of all inputs and write the result.

    Parameters
    ----------
    groups
        Input groups; each one is indexed with ``f"{alt_axis_name}p"``.
    output_group
        Group the merged mapping is written into.
    alt_axis_name
        Name of the axis that is *not* being concatenated.
    merge
        Callable that receives one ``dict`` per input group and returns the
        mapping to write.
    reindexers
        One reindexer per entry of ``groups``, in the same order.

    Raises
    ------
    ValueError
        If ``reindexers`` and ``groups`` differ in length (``zip(strict=True)``).
    """
    pairwise_key = f"{alt_axis_name}p"
    # Elements are opened with ``read_elem_lazy`` and the same reindexer is
    # applied twice: first along axis 0, then along axis 1.
    # NOTE(review): presumably because pairwise elements are indexed by the
    # alt axis on both dimensions; confirm.
    per_group = [
        {
            name: reindex(reindex(read_elem_lazy(elem), axis=0), axis=1)
            for name, elem in dict(group[pairwise_key]).items()
        }
        for reindex, group in zip(reindexers, groups, strict=True)
    ]
    write_elem(output_group, pairwise_key, merge(per_group))
442+
443+
411444def concat_on_disk ( # noqa: PLR0912, PLR0913, PLR0915
412445 in_files : Collection [PathLike [str ] | str ] | Mapping [str , PathLike [str ] | str ],
413446 out_file : PathLike [str ] | str ,
@@ -490,7 +523,8 @@ def concat_on_disk( # noqa: PLR0912, PLR0913, PLR0915
490523 DataFrames are padded with missing values.
491524 pairwise
492525 Whether pairwise elements along the concatenated dimension should be included.
493- This is False by default, since the resulting arrays are often not meaningful.
526+ This is False by default, since the resulting arrays are often not meaningful, and raises :class:`NotImplementedError` when True.
527+ If you are interested in this feature, please open an issue.
494528
495529 Notes
496530 -----
@@ -634,7 +668,10 @@ def concat_on_disk( # noqa: PLR0912, PLR0913, PLR0915
634668 _write_alt_annot (groups , output_group , alt_axis_name , alt_index , merge )
635669
636670 # Write {alt_axis_name}m
637- _write_alt_mapping (groups , output_group , alt_axis_name , alt_index , merge )
671+ _write_alt_mapping (groups , output_group , alt_axis_name , merge , reindexers )
672+
673+ # Write {alt_axis_name}p
674+ _write_alt_pairwise (groups , output_group , alt_axis_name , merge , reindexers )
638675
639676 # Write X
640677
0 commit comments