1- use crate :: backend:: ScalarType ;
21use crate :: data:: utils:: { array_major_minor_index_default, cs_major_minor_index2} ;
32use crate :: data:: { DataFrameIndex , DynCsrMatrix } ;
4- use crate :: { AnnDataOp , ArrayElemOp } ;
3+ use crate :: { AnnDataOp , ArrayElemOp , AxisArraysOp , HasShape } ;
54use anyhow:: { ensure, Result } ;
65use indexmap:: IndexSet ;
76use itertools:: Itertools ;
7+ use log:: warn;
88use nalgebra_sparse:: csr:: CsrMatrix ;
99use nalgebra_sparse:: pattern:: SparsityPattern ;
1010use polars:: chunked_array:: builder:: CategoricalChunkedBuilder ;
@@ -104,44 +104,51 @@ where
104104 }
105105
106106 // Concatenate X
107+ if adatas. iter ( ) . any ( |adata| adata. x ( ) . is_none ( ) ) {
108+ warn ! ( "Some AnnData objects have no X matrix. The concatenated X matrix will be None." ) ;
109+ } else {
110+ out. set_x_from_iter ( concat_x ( adatas, & common_vars) ) ?;
111+ }
112+
113+ // Concatenate obsm
107114 {
108- if adatas. iter ( ) . any ( |adata| !adata. x ( ) . is_none ( ) ) {
109- let dtype = adatas
110- . iter ( )
111- . flat_map ( |x| x. x ( ) . dtype ( ) . and_then ( |d| d. scalar_type ( ) ) )
112- . next ( )
113- . unwrap ( ) ;
114- let x_arr = adatas. iter ( ) . map ( |adata| {
115- let n_obs = adata. n_obs ( ) ;
116- let n_vars = adata. n_vars ( ) ;
117- let var_names = adata. var_names ( ) ;
115+ let obsm: Vec < _ > = adatas. iter ( ) . map ( |x| x. obsm ( ) ) . collect ( ) ;
116+ let common_keys = obsm
117+ . iter ( )
118+ . map ( |x| x. keys ( ) . into_iter ( ) . collect :: < IndexSet < _ > > ( ) )
119+ . reduce ( |a, b| a. intersection ( & b) . cloned ( ) . collect ( ) )
120+ . unwrap ( ) ;
121+ for key in common_keys {
122+ let arr = concat_axis_arrays ( & obsm, & key) ;
123+ out. obsm ( ) . add_iter ( & key, arr) ?;
124+ }
125+ }
118126
119- macro_rules! fun {
120- ( $variant: ident) => {
121- CsrMatrix :: <$variant>:: zeros( n_obs, n_vars) . into( )
122- } ;
123- }
127+ // Concatenate obsp
128+ {
129+ let obsp: Vec < _ > = adatas. iter ( ) . map ( |x| x. obsp ( ) ) . collect ( ) ;
130+ let common_keys = obsp
131+ . iter ( )
132+ . map ( |x| x. keys ( ) . into_iter ( ) . collect :: < IndexSet < _ > > ( ) )
133+ . reduce ( |a, b| a. intersection ( & b) . cloned ( ) . collect ( ) )
134+ . unwrap ( ) ;
135+ for key in common_keys {
136+ let arr = concat_axis_arrays ( & obsp, & key) ;
137+ out. obsp ( ) . add_iter ( & key, arr) ?;
138+ }
139+ }
124140
125- adata
126- . x ( )
127- . get ( )
128- . unwrap ( )
129- . map ( |arr| {
130- index_array (
131- arr,
132- & ( 0 ..adata. n_obs ( ) )
133- . into_iter ( )
134- . map ( |x| Some ( x) )
135- . collect :: < Vec < _ > > ( ) ,
136- & common_vars
137- . iter ( )
138- . map ( |x| var_names. get_index ( x) )
139- . collect :: < Vec < _ > > ( ) ,
140- )
141- } )
142- . unwrap_or_else ( || crate :: macros:: dyn_match!( dtype, ScalarType , fun) )
143- } ) ;
144- out. set_x_from_iter ( x_arr) ?;
141+ // Concat layers
142+ {
143+ let layers: Vec < _ > = adatas. iter ( ) . map ( |x| x. layers ( ) ) . collect ( ) ;
144+ let common_keys = layers
145+ . iter ( )
146+ . map ( |x| x. keys ( ) . into_iter ( ) . collect :: < IndexSet < _ > > ( ) )
147+ . reduce ( |a, b| a. intersection ( & b) . cloned ( ) . collect ( ) )
148+ . unwrap ( ) ;
149+ for key in common_keys {
150+ let arr = concat_axis_arrays ( & layers, & key) ;
151+ out. layers ( ) . add_iter ( & key, arr) ?;
145152 }
146153 }
147154
@@ -292,3 +299,30 @@ fn index_array(
292299 _ => todo ! ( ) ,
293300 }
294301}
302+
303+ fn concat_x < A : AnnDataOp > ( adatas : & [ A ] , common_vars : & IndexSet < String > ) -> impl Iterator < Item = ArrayData > {
304+ adatas. iter ( ) . map ( move |adata| {
305+ let var_names = adata. var_names ( ) ;
306+ let arr = adata. x ( ) . get ( ) . unwrap ( ) . unwrap ( ) ;
307+ index_array (
308+ arr,
309+ & ( 0 ..adata. n_obs ( ) )
310+ . into_iter ( )
311+ . map ( |x| Some ( x) )
312+ . collect :: < Vec < _ > > ( ) ,
313+ & common_vars
314+ . iter ( )
315+ . map ( |x| var_names. get_index ( x) )
316+ . collect :: < Vec < _ > > ( ) ,
317+ )
318+ } )
319+ }
320+
321+ fn concat_axis_arrays < A : AxisArraysOp > ( axis_arrays : & [ A ] , key : & str ) -> impl Iterator < Item = ArrayData > {
322+ let size = axis_arrays[ 0 ] . get ( key) . unwrap ( ) . shape ( ) . unwrap ( ) [ 1 ] ;
323+ axis_arrays. iter ( ) . map ( move |arr| {
324+ let arr: ArrayData = arr. get_item ( key) . unwrap ( ) . unwrap ( ) ;
325+ assert_eq ! ( arr. shape( ) [ 1 ] , size, "dimension mismatch for key: {}" , key) ;
326+ arr
327+ } )
328+ }
0 commit comments