@@ -4,15 +4,18 @@ pub use dataset::{AnnDataSet, StackedAnnData};
44
55use crate :: {
66 ArrayElemOp , AxisArraysOp , ElemCollectionOp ,
7- backend:: { Backend , DataContainer , GroupOp , StoreOp } ,
7+ backend:: { Backend , DataContainer , DataType , GroupOp , StoreOp } ,
88 container:: { ArrayElem , Axis , AxisArrays , DataFrameElem , Dim , ElemCollection , Slot } ,
99 data:: * ,
1010 traits:: AnnDataOp ,
1111} ;
1212
1313use anyhow:: { Result , ensure} ;
1414use itertools:: Itertools ;
15- use std:: path:: { Path , PathBuf } ;
15+ use std:: {
16+ collections:: HashMap ,
17+ path:: { Path , PathBuf } ,
18+ } ;
1619
1720/// Represents an annotated data object backed by a specified backend.
1821pub struct AnnData < B : Backend > {
@@ -441,4 +444,99 @@ impl<B: Backend> AnnData<B> {
441444
442445 Ok ( ( ) )
443446 }
447+
448+ /// Split the AnnData object into multiple AnnData objects based on grouping keys.
449+ /// The length of the grouping keys must match the number of observations.
450+ ///
451+ /// # Arguments:
452+ /// * `keys` - A slice of grouping keys.
453+ /// * `out_dir` - The output directory to save the split AnnData files.
454+ ///
455+ /// # Returns:
456+ /// A HashMap mapping each unique key to its corresponding AnnData object.
457+ pub fn split_by < P : AsRef < Path > > ( & self , keys : & [ String ] , out_dir : P ) -> Result < HashMap < String , AnnData < B > > > {
458+ let n_vars = self . n_vars ( ) ;
459+ ensure ! (
460+ keys. len( ) == self . n_obs( ) ,
461+ "Length of grouping keys must match number of observations"
462+ ) ;
463+
464+ let mut indices = HashMap :: new ( ) ;
465+ keys. iter ( ) . enumerate ( ) . for_each ( |( i, key) | {
466+ indices. entry ( key) . or_insert ( Vec :: new ( ) ) . push ( i) ;
467+ } ) ;
468+
469+ // Create anndata files
470+ std:: fs:: create_dir_all ( & out_dir) ?;
471+ let adatas = indices
472+ . into_iter ( )
473+ . map ( |( key, idx) | {
474+ let path = out_dir. as_ref ( ) . join ( format ! ( "{}.h5ad" , key) ) ;
475+ let adata = AnnData :: < B > :: new ( & path) ?;
476+ adata. set_n_obs ( idx. len ( ) ) ?;
477+ adata. set_n_vars ( n_vars) ?;
478+ Ok ( ( key, ( adata, idx) ) )
479+ } )
480+ . collect :: < Result < HashMap < _ , _ > > > ( ) ?;
481+
482+ // Copy obs and var
483+ {
484+ let obs_names = self . obs_names ( ) ;
485+ let var_names = self . var_names ( ) ;
486+ let var = self . read_var ( ) ?;
487+ let obs = self . read_obs ( ) ?;
488+ adatas. iter ( ) . try_for_each ( |( _, ( adata, idx) ) | {
489+ let idx = SelectInfoElem :: from ( idx) ;
490+ if !obs_names. is_empty ( ) {
491+ adata. set_obs_names ( obs_names. select ( & idx) ) ?;
492+ }
493+ if !var_names. is_empty ( ) {
494+ adata. set_var_names ( var_names. clone ( ) ) ?;
495+ }
496+ adata. set_var ( var. clone ( ) ) ?;
497+ adata. set_obs ( Selectable :: select_axis ( & obs, 0 , idx) ) ?;
498+ anyhow:: Ok ( ( ) )
499+ } ) ?;
500+ }
501+
502+ // Copy X
503+ if let Some ( dtype) = self . x ( ) . dtype ( ) {
504+ let mut x_builders: HashMap < _ , _ > = match dtype {
505+ DataType :: Array ( t) => {
506+ adatas. iter ( ) . map ( |( key, ( adata, _) ) | {
507+ ( key, MatrixBuilder :: new_dense ( & adata. file , "X" , t) . unwrap ( ) )
508+ } ) . collect ( )
509+ }
510+ DataType :: CsrMatrix ( t) => {
511+ adatas. iter ( ) . map ( |( key, ( adata, _) ) | {
512+ ( key, MatrixBuilder :: new_sparse ( & adata. file , "X" , t) . unwrap ( ) )
513+ } ) . collect ( )
514+ }
515+ _ => anyhow:: bail!( "Unsupported data type for X" ) ,
516+ } ;
517+
518+ self . x ( ) . iter :: < ArrayData > ( 10000 ) . for_each ( |( chunk, start, end) | {
519+ let mut idx = HashMap :: new ( ) ;
520+ keys[ start..end] . iter ( ) . enumerate ( ) . for_each ( |( i, key) | {
521+ idx. entry ( key) . or_insert ( Vec :: new ( ) ) . push ( i) ;
522+ } ) ;
523+
524+ idx. into_iter ( ) . for_each ( |( key, rows) | {
525+ let arr = chunk. select_axis ( 0 , SelectInfoElem :: from ( & rows) ) ;
526+ x_builders. get_mut ( & key) . unwrap ( ) . add ( arr) . unwrap ( ) ;
527+ } ) ;
528+ } ) ;
529+
530+ x_builders. into_iter ( ) . try_for_each ( |( key, builder) | {
531+ adatas. get ( key) . unwrap ( ) . 0 . x . swap ( & builder. finish ( ) ?) ;
532+ anyhow:: Ok ( ( ) )
533+ } ) ?;
534+ }
535+
536+ adatas. into_iter ( )
537+ . map ( |( key, ( adata, _) ) | {
538+ Ok ( ( key. to_string ( ) , adata) )
539+ } )
540+ . collect ( )
541+ }
444542}
0 commit comments