Skip to content

Commit 36c6f58

Browse files
committed
modify concat
1 parent ba7495d commit 36c6f58

3 files changed

Lines changed: 101 additions & 46 deletions

File tree

anndata/src/concat.rs

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
use crate::backend::ScalarType;
21
use crate::data::utils::{array_major_minor_index_default, cs_major_minor_index2};
32
use crate::data::{DataFrameIndex, DynCsrMatrix};
4-
use crate::{AnnDataOp, ArrayElemOp};
3+
use crate::{AnnDataOp, ArrayElemOp, AxisArraysOp, HasShape};
54
use anyhow::{ensure, Result};
65
use indexmap::IndexSet;
76
use itertools::Itertools;
7+
use log::warn;
88
use nalgebra_sparse::csr::CsrMatrix;
99
use nalgebra_sparse::pattern::SparsityPattern;
1010
use polars::chunked_array::builder::CategoricalChunkedBuilder;
@@ -104,44 +104,51 @@ where
104104
}
105105

106106
// Concatenate X
107+
if adatas.iter().any(|adata| adata.x().is_none()) {
108+
warn!("Some AnnData objects have no X matrix. The concatenated X matrix will be None.");
109+
} else {
110+
out.set_x_from_iter(concat_x(adatas, &common_vars))?;
111+
}
112+
113+
// Concatenate obsm
107114
{
108-
if adatas.iter().any(|adata| !adata.x().is_none()) {
109-
let dtype = adatas
110-
.iter()
111-
.flat_map(|x| x.x().dtype().and_then(|d| d.scalar_type()))
112-
.next()
113-
.unwrap();
114-
let x_arr = adatas.iter().map(|adata| {
115-
let n_obs = adata.n_obs();
116-
let n_vars = adata.n_vars();
117-
let var_names = adata.var_names();
115+
let obsm: Vec<_> = adatas.iter().map(|x| x.obsm()).collect();
116+
let common_keys = obsm
117+
.iter()
118+
.map(|x| x.keys().into_iter().collect::<IndexSet<_>>())
119+
.reduce(|a, b| a.intersection(&b).cloned().collect())
120+
.unwrap();
121+
for key in common_keys {
122+
let arr = concat_axis_arrays(&obsm, &key);
123+
out.obsm().add_iter(&key, arr)?;
124+
}
125+
}
118126

119-
macro_rules! fun {
120-
($variant:ident) => {
121-
CsrMatrix::<$variant>::zeros(n_obs, n_vars).into()
122-
};
123-
}
127+
// Concatenate obsp
128+
{
129+
let obsp: Vec<_> = adatas.iter().map(|x| x.obsp()).collect();
130+
let common_keys = obsp
131+
.iter()
132+
.map(|x| x.keys().into_iter().collect::<IndexSet<_>>())
133+
.reduce(|a, b| a.intersection(&b).cloned().collect())
134+
.unwrap();
135+
for key in common_keys {
136+
let arr = concat_axis_arrays(&obsp, &key);
137+
out.obsp().add_iter(&key, arr)?;
138+
}
139+
}
124140

125-
adata
126-
.x()
127-
.get()
128-
.unwrap()
129-
.map(|arr| {
130-
index_array(
131-
arr,
132-
&(0..adata.n_obs())
133-
.into_iter()
134-
.map(|x| Some(x))
135-
.collect::<Vec<_>>(),
136-
&common_vars
137-
.iter()
138-
.map(|x| var_names.get_index(x))
139-
.collect::<Vec<_>>(),
140-
)
141-
})
142-
.unwrap_or_else(|| crate::macros::dyn_match!(dtype, ScalarType, fun))
143-
});
144-
out.set_x_from_iter(x_arr)?;
141+
// Concat layers
142+
{
143+
let layers: Vec<_> = adatas.iter().map(|x| x.layers()).collect();
144+
let common_keys = layers
145+
.iter()
146+
.map(|x| x.keys().into_iter().collect::<IndexSet<_>>())
147+
.reduce(|a, b| a.intersection(&b).cloned().collect())
148+
.unwrap();
149+
for key in common_keys {
150+
let arr = concat_axis_arrays(&layers, &key);
151+
out.layers().add_iter(&key, arr)?;
145152
}
146153
}
147154

@@ -292,3 +299,30 @@ fn index_array(
292299
_ => todo!(),
293300
}
294301
}
302+
303+
/// Builds a lazy iterator that yields one re-indexed `X` array per input object.
///
/// For every element of `adatas` the closure reads the object's `X` matrix and
/// passes it to `index_array` together with
/// * a row selection covering every observation (`Some(0)`, …, `Some(n_obs - 1)`), and
/// * a column selection obtained by looking up each name of `common_vars` in the
///   object's own `var_names`, in the iteration order of `common_vars`.
///
/// # Panics
///
/// The double `unwrap` on `adata.x().get()` panics when the read fails or yields
/// no value (presumably a `Result<Option<_>>` -- confirm against `ArrayElemOp`).
/// The visible call site only invokes this function after checking that no
/// input has an empty `X`. Because `map` is lazy, any panic happens while the
/// returned iterator is being consumed, not when `concat_x` is called.
fn concat_x<A: AnnDataOp>(adatas: &[A], common_vars: &IndexSet<String>) -> impl Iterator<Item = ArrayData> {
304+
adatas.iter().map(move |adata| {
305+
let var_names = adata.var_names();
306+
// Panics if X is missing or cannot be read (see `# Panics` above).
let arr = adata.x().get().unwrap().unwrap();
307+
index_array(
308+
arr,
309+
// Row selection: keep every observation, in its original order.
&(0..adata.n_obs())
310+
.into_iter()
311+
.map(|x| Some(x))
312+
.collect::<Vec<_>>(),
313+
// Column selection: position of each common variable within this object.
// NOTE(review): `get_index` presumably returns `None` for a name this object
// lacks; how `index_array` treats such entries is not visible here -- confirm.
&common_vars
314+
.iter()
315+
.map(|x| var_names.get_index(x))
316+
.collect::<Vec<_>>(),
317+
)
318+
})
319+
}
320+
321+
/// Builds a lazy iterator over the arrays stored under `key` in each element of
/// `axis_arrays`, checking that they all agree on the size of dimension 1.
///
/// The reference size is taken from `shape()[1]` of the entry found under `key`
/// in the first container. Each yielded array is asserted to have the same
/// `shape()[1]`.
///
/// # Panics
///
/// * immediately, if `axis_arrays` is empty (indexing `[0]`), if `key` is absent
///   from the first container, or if that entry's shape is unavailable;
/// * while the returned iterator is consumed, if `get_item` fails or yields no
///   value for some container, or if an array's `shape()[1]` differs from the
///   reference size (message: "dimension mismatch for key: …").
///
/// NOTE(review): all failure modes are panics rather than `Result`s; the visible
/// callers pass only keys common to every container -- confirm that this makes
/// the `unwrap`s unreachable in practice.
fn concat_axis_arrays<A: AxisArraysOp>(axis_arrays: &[A], key: &str) -> impl Iterator<Item = ArrayData> {
322+
// Reference size of dimension 1, read once from the first container.
let size = axis_arrays[0].get(key).unwrap().shape().unwrap()[1];
323+
axis_arrays.iter().map(move |arr| {
324+
let arr: ArrayData = arr.get_item(key).unwrap().unwrap();
325+
// Runs per element at iteration time because `map` is lazy.
assert_eq!(arr.shape()[1], size, "dimension mismatch for key: {}", key);
326+
arr
327+
})
328+
}

pyanndata/src/anndata.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,36 @@ pub fn read<'py>(
7575
}
7676

7777
/// Concatenates AnnData objects.
78+
///
79+
/// When `file` is provided, this function saves the merged AnnData object to disk
80+
/// in a streaming fashion. This is memory efficient and allows merging large datasets
81+
/// that do not fit into memory.
7882
///
7983
/// Parameters
8084
/// ----------
8185
///
82-
/// filename: Path
83-
/// File name of data file.
84-
/// backed: Literal['r', 'r+'] | None
85-
/// Default is `r+`.
86-
/// If `'r'`, the file is opened in read-only mode.
87-
/// If `'r+'`, the file is opened in read/write mode.
88-
/// If `None`, the AnnData object is read into memory.
89-
/// backend: Literal['hdf5'] | None
86+
/// adatas: list[AnnData]
87+
/// List of AnnData objects to concatenate.
88+
/// join: Literal['inner', 'outer']
89+
/// How to handle observations and variables that are not shared between all AnnData objects.
90+
/// label: str | None
91+
/// Column in axis annotation (i.e. .obs or .var) to place batch information in. If it’s None, no column is added.
92+
/// keys
93+
/// Names for each object being added. These values are used as the column values of `label`.
94+
/// file: Path | None
95+
/// If provided, the concatenated AnnData will be saved to this file.
96+
/// backend: Literal['hdf5', 'zarr'] | None
97+
/// Backend to use for writing the output file.
98+
///
99+
/// Returns
100+
/// -------
101+
///
102+
/// AnnData
103+
/// The concatenated AnnData object.
104+
///
105+
/// See Also
106+
/// --------
107+
/// AnnDataSet
90108
#[pyfunction]
91109
#[pyo3(
92110
signature = (adatas, *, join="inner", label=None, keys=None, file=None, backend=None),

pyanndata/src/anndata/dataset.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,16 @@ use super::get_backend;
4040
4141
Note
4242
----
43-
AnnDataSet does not copy underlying AnnData objects. It stores the references
43+
- AnnDataSet does not copy underlying AnnData objects. It stores the references
4444
to individual anndata files. If you move the anndata files to a new location,
4545
remember to update the anndata file locations when opening an AnnDataSet object.
46+
- AnnDataSet requires all component anndata files to have the same set of var names.
47+
To concatenate AnnData objects with different var names, please use the `concat` function.
4648
4749
See Also
4850
--------
4951
read_dataset
52+
concat
5053
*/
5154
#[pyclass]
5255
#[repr(transparent)]

0 commit comments

Comments
 (0)