Skip to content

Commit ca9e81d

Browse files
committed
add split_by
1 parent 1e398f6 commit ca9e81d

6 files changed

Lines changed: 665 additions & 106 deletions

File tree

anndata/src/anndata.rs

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@ pub use dataset::{AnnDataSet, StackedAnnData};
44

55
use crate::{
66
ArrayElemOp, AxisArraysOp, ElemCollectionOp,
7-
backend::{Backend, DataContainer, GroupOp, StoreOp},
7+
backend::{Backend, DataContainer, DataType, GroupOp, StoreOp},
88
container::{ArrayElem, Axis, AxisArrays, DataFrameElem, Dim, ElemCollection, Slot},
99
data::*,
1010
traits::AnnDataOp,
1111
};
1212

1313
use anyhow::{Result, ensure};
1414
use itertools::Itertools;
15-
use std::path::{Path, PathBuf};
15+
use std::{
16+
collections::HashMap,
17+
path::{Path, PathBuf},
18+
};
1619

1720
/// Represents an annotated data object backed by a specified backend.
1821
pub struct AnnData<B: Backend> {
@@ -441,4 +444,99 @@ impl<B: Backend> AnnData<B> {
441444

442445
Ok(())
443446
}
447+
448+
/// Split the AnnData object into multiple AnnData objects based on grouping keys.
449+
/// The length of the grouping keys must match the number of observations.
450+
///
451+
/// # Arguments:
452+
/// * `keys` - A slice of grouping keys.
453+
/// * `out_dir` - The output directory to save the split AnnData files.
454+
///
455+
/// # Returns:
456+
/// A HashMap mapping each unique key to its corresponding AnnData object.
457+
pub fn split_by<P: AsRef<Path>>(&self, keys: &[String], out_dir: P) -> Result<HashMap<String, AnnData<B>>> {
458+
let n_vars = self.n_vars();
459+
ensure!(
460+
keys.len() == self.n_obs(),
461+
"Length of grouping keys must match number of observations"
462+
);
463+
464+
let mut indices = HashMap::new();
465+
keys.iter().enumerate().for_each(|(i, key)| {
466+
indices.entry(key).or_insert(Vec::new()).push(i);
467+
});
468+
469+
// Create anndata files
470+
std::fs::create_dir_all(&out_dir)?;
471+
let adatas = indices
472+
.into_iter()
473+
.map(|(key, idx)| {
474+
let path = out_dir.as_ref().join(format!("{}.h5ad", key));
475+
let adata = AnnData::<B>::new(&path)?;
476+
adata.set_n_obs(idx.len())?;
477+
adata.set_n_vars(n_vars)?;
478+
Ok((key, (adata, idx)))
479+
})
480+
.collect::<Result<HashMap<_, _>>>()?;
481+
482+
// Copy obs and var
483+
{
484+
let obs_names = self.obs_names();
485+
let var_names = self.var_names();
486+
let var = self.read_var()?;
487+
let obs = self.read_obs()?;
488+
adatas.iter().try_for_each(|(_, (adata, idx))| {
489+
let idx = SelectInfoElem::from(idx);
490+
if !obs_names.is_empty() {
491+
adata.set_obs_names(obs_names.select(&idx))?;
492+
}
493+
if !var_names.is_empty() {
494+
adata.set_var_names(var_names.clone())?;
495+
}
496+
adata.set_var(var.clone())?;
497+
adata.set_obs(Selectable::select_axis(&obs, 0, idx))?;
498+
anyhow::Ok(())
499+
})?;
500+
}
501+
502+
// Copy X
503+
if let Some(dtype) = self.x().dtype() {
504+
let mut x_builders: HashMap<_, _> = match dtype {
505+
DataType::Array(t) => {
506+
adatas.iter().map(|(key, (adata, _))| {
507+
(key, MatrixBuilder::new_dense(&adata.file, "X", t).unwrap())
508+
}).collect()
509+
}
510+
DataType::CsrMatrix(t) => {
511+
adatas.iter().map(|(key, (adata, _))| {
512+
(key, MatrixBuilder::new_sparse(&adata.file, "X", t).unwrap())
513+
}).collect()
514+
}
515+
_ => anyhow::bail!("Unsupported data type for X"),
516+
};
517+
518+
self.x().iter::<ArrayData>(10000).for_each(|(chunk, start, end)| {
519+
let mut idx = HashMap::new();
520+
keys[start..end].iter().enumerate().for_each(|(i, key)| {
521+
idx.entry(key).or_insert(Vec::new()).push(i);
522+
});
523+
524+
idx.into_iter().for_each(|(key, rows)| {
525+
let arr = chunk.select_axis(0, SelectInfoElem::from(&rows));
526+
x_builders.get_mut(&key).unwrap().add(arr).unwrap();
527+
});
528+
});
529+
530+
x_builders.into_iter().try_for_each(|(key, builder)| {
531+
adatas.get(key).unwrap().0.x.swap(&builder.finish()?);
532+
anyhow::Ok(())
533+
})?;
534+
}
535+
536+
adatas.into_iter()
537+
.map(|(key, (adata, _))| {
538+
Ok((key.to_string(), adata))
539+
})
540+
.collect()
541+
}
444542
}

anndata/src/data/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ pub mod slice;
55
mod sparse;
66
pub mod utils;
77

8-
pub use chunks::ArrayChunk;
8+
pub use chunks::{MatrixBuilder, ArrayChunk};
99
pub use dataframe::DataFrameIndex;
1010
pub use dense::{ArrayConvert, CategoricalArray, DynArray, DynCowArray, DynScalar};
1111
pub use slice::{SelectInfo, SelectInfoBounds, SelectInfoElem, SelectInfoElemBounds, Shape};

0 commit comments

Comments
 (0)