Skip to content

Commit 97c61b8

Browse files
committed
allow partial copying and writing of AnnData objects
1 parent ca9e81d commit 97c61b8

3 files changed

Lines changed: 115 additions & 43 deletions

File tree

anndata/src/anndata.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
use anyhow::{Result, ensure};
1414
use itertools::Itertools;
1515
use std::{
16-
collections::HashMap,
16+
collections::{HashMap, HashSet},
1717
path::{Path, PathBuf},
1818
};
1919

@@ -260,42 +260,68 @@ impl<B: Backend> AnnData<B> {
260260
/// # Arguments
261261
///
262262
/// * `filename` - The path to the output file.
263+
/// * `partial` - If Some, writes only the specified fields. If None, writes all fields.
264+
/// This can be useful for saving space when only a subset of the data is needed.
265+
/// Possible fields are: "X", "obs", "var", "obsm", "obsp", "varm", "varp", "uns", "layers".
263266
/// * `chunk_size` - If None, writes the entire data matrix at once. Otherwise,
264267
/// writes the data matrix in chunks of the specified size.
265268
/// This can be useful for saving large datasets that do not fit into memory.
266269
pub fn write<O: Backend, P: AsRef<Path>>(
267270
&self,
268271
filename: P,
272+
partial: Option<HashSet<String>>,
269273
chunk_size: Option<usize>,
270274
) -> Result<()> {
275+
let saved_fields = match partial {
276+
Some(set) => set,
277+
None => [
278+
"X", "obs", "var", "obsm", "obsp", "varm", "varp", "uns", "layers",
279+
]
280+
.into_iter()
281+
.map(String::from)
282+
.collect(),
283+
};
284+
271285
let adata = AnnData::<O>::new(filename)?;
272286

273287
adata.set_n_obs(self.n_obs())?;
274288
adata.set_n_vars(self.n_vars())?;
275289

276-
if !self.get_obs().is_none() {
290+
if !self.get_obs().is_none() && saved_fields.contains("obs") {
277291
adata.set_obs_names(self.obs_names())?;
278292
adata.set_obs(self.read_obs()?)?;
279293
}
280-
if !self.get_var().is_none() {
294+
if !self.get_var().is_none() && saved_fields.contains("var") {
281295
adata.set_var_names(self.var_names())?;
282296
adata.set_var(self.read_var()?)?;
283297
}
284298

285-
if !self.x().is_none() {
299+
if !self.x().is_none() && saved_fields.contains("X") {
286300
if let Some(chunk_size) = chunk_size {
287301
adata.set_x_from_iter(self.x().iter::<ArrayData>(chunk_size).map(|x| x.0))?;
288302
} else {
289303
adata.set_x(self.x().get::<ArrayData>()?.unwrap())?;
290304
}
291305
}
292306

293-
adata.set_obsm(self.obsm().iter_item::<ArrayData>())?;
294-
adata.set_obsp(self.obsp().iter_item::<ArrayData>())?;
295-
adata.set_varm(self.varm().iter_item::<ArrayData>())?;
296-
adata.set_varp(self.varp().iter_item::<ArrayData>())?;
297-
adata.set_uns(self.uns().iter_item::<Data>())?;
298-
adata.set_layers(self.layers().iter_item::<ArrayData>())?;
307+
if saved_fields.contains("obsm") {
308+
adata.set_obsm(self.obsm().iter_item::<ArrayData>())?;
309+
}
310+
if saved_fields.contains("obsp") {
311+
adata.set_obsp(self.obsp().iter_item::<ArrayData>())?;
312+
}
313+
if saved_fields.contains("varm") {
314+
adata.set_varm(self.varm().iter_item::<ArrayData>())?;
315+
}
316+
if saved_fields.contains("varp") {
317+
adata.set_varp(self.varp().iter_item::<ArrayData>())?;
318+
}
319+
if saved_fields.contains("uns") {
320+
adata.set_uns(self.uns().iter_item::<Data>())?;
321+
}
322+
if saved_fields.contains("layers") {
323+
adata.set_layers(self.layers().iter_item::<ArrayData>())?;
324+
}
299325

300326
adata.close()
301327
}

anndata/src/anndata/dataset.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ impl<B: Backend> AnnDataSet<B> {
302302

303303
/// Convert AnnDataSet to AnnData object
304304
pub fn to_adata<O: Backend, P: AsRef<Path>>(&self, out: P, copy_x: bool) -> Result<AnnData<O>> {
305-
self.annotation.write::<O, _>(&out, None)?;
305+
self.annotation.write::<O, _>(&out, None, None)?;
306306
let adata = AnnData::open(O::open_rw(&out)?)?;
307307
if copy_x {
308308
adata.set_x_from_iter::<_, ArrayData>(

pyanndata/src/anndata/backed.rs

Lines changed: 78 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@ use crate::anndata::PyAnnData;
22
use crate::container::{
33
PyArrayElem, PyAxisArrays, PyChunkedArray, PyDataFrameElem, PyElemCollection,
44
};
5-
use crate::data::{isinstance_of_pandas, to_select_elem, PyArrayData, PyData};
5+
use crate::data::{PyArrayData, PyData, isinstance_of_pandas, to_select_elem};
66

77
use anndata::container::Slot;
88
use anndata::data::{DataFrameIndex, SelectInfoElem, SelectInfoElemBounds};
99
use anndata::{self, ArrayElemOp, AxisArraysOp, Data, ElemCollectionOp, Selectable};
1010
use anndata::{AnnDataOp, ArrayData, Backend};
1111
use anndata_hdf5::H5;
12-
use anyhow::{bail, Result};
13-
use downcast_rs::{impl_downcast, Downcast};
12+
use anyhow::{Result, bail};
13+
use downcast_rs::{Downcast, impl_downcast};
1414
use pyo3::prelude::*;
15+
use pyo3::types::PyIterator;
1516
use pyo3_polars::PyDataFrame;
16-
use std::collections::HashMap;
17+
use std::collections::{HashMap, HashSet};
1718
use std::ops::Deref;
1819
use std::path::PathBuf;
1920

@@ -430,14 +431,14 @@ impl AnnData {
430431

431432
/// Split the AnnData object into multiple AnnData objects based on grouping keys.
432433
/// The length of the grouping keys must match the number of observations.
433-
///
434+
///
434435
/// Parameters
435436
/// ----------
436437
/// key: str | list[str]
437438
/// Grouping key(s). Can be a single column name or a list of column names in `obs`.
438439
/// out_dir: Path
439440
/// Output directory to save the split AnnData files.
440-
///
441+
///
441442
/// Returns
442443
/// -------
443444
/// dict[str, AnnData]
@@ -531,16 +532,24 @@ impl AnnData {
531532
/// File name of the output `.h5ad` file.
532533
/// backend: Literal['hdf5', 'zarr']
533534
/// The backend to use. "hdf5" or "zarr" are supported.
535+
/// partial : list[str] | None
536+
/// A list of fields to copy. If None, copies all fields. Possible fields are:
537+
/// "X", "obs", "var", "obsm", "obsp", "varm", "varp", "uns", "layers".
534538
/// chunk_size : int | None
535539
/// If None, writes the entire data matrix at once. Otherwise,
536540
/// rites the data matrix in chunks of the specified size.
537541
/// This can be useful for saving large datasets that do not fit into memory.
538542
#[pyo3(
539-
signature = (filename, backend=H5::NAME, chunk_size=None),
540-
text_signature = "($self, filename, backend='hdf5', chunk_size=None)",
543+
signature = (filename, backend=H5::NAME, partial=None, chunk_size=None),
544+
text_signature = "($self, filename, backend='hdf5', partial=None, chunk_size=None)",
541545
)]
542-
pub fn write(&self, filename: PathBuf, backend: &str, chunk_size: Option<usize>) -> Result<()> {
543-
self.0.write(filename, backend, chunk_size)
546+
pub fn write(&self, filename: PathBuf, backend: &str, partial: Option<Bound<PyIterator>>, chunk_size: Option<usize>) -> Result<()> {
547+
let partial = partial.map(|p| {
548+
p.into_iter()
549+
.map(|x| x.unwrap().extract::<String>().unwrap())
550+
.collect::<HashSet<_>>()
551+
});
552+
self.0.write(filename, backend, partial, chunk_size)
544553
}
545554

546555
/// Copy the AnnData object.
@@ -551,6 +560,9 @@ impl AnnData {
551560
/// File name of the output `.h5ad` file.
552561
/// backend: Literal['hdf5', 'zarr']
553562
/// The backend to use. "hdf5" or "zarr" are supported.
563+
/// partial : list[str] | None
564+
/// A list of fields to copy. If None, copies all fields. Possible fields are:
565+
/// "X", "obs", "var", "obsm", "obsp", "varm", "varp", "uns", "layers".
554566
/// chunk_size : int | None
555567
/// If None, writes the entire data matrix at once. Otherwise,
556568
/// rites the data matrix in chunks of the specified size.
@@ -560,11 +572,16 @@ impl AnnData {
560572
/// -------
561573
/// AnnData
562574
#[pyo3(
563-
signature = (filename, backend=H5::NAME, chunk_size=None),
564-
text_signature = "($self, filename, backend='hdf5', chunk_size=None)",
575+
signature = (filename, backend=H5::NAME, partial=None, chunk_size=None),
576+
text_signature = "($self, filename, backend='hdf5', partial=None, chunk_size=None)",
565577
)]
566-
fn copy(&self, filename: PathBuf, backend: &str, chunk_size: Option<usize>) -> Result<Self> {
567-
self.0.copy(filename, backend, chunk_size)
578+
fn copy(&self, filename: PathBuf, backend: &str, partial: Option<Bound<PyIterator>>, chunk_size: Option<usize>) -> Result<Self> {
579+
let partial = partial.map(|p| {
580+
p.into_iter()
581+
.map(|x| x.unwrap().extract::<String>().unwrap())
582+
.collect::<HashSet<_>>()
583+
});
584+
self.0.copy(filename, backend, partial, chunk_size)
568585
}
569586

570587
/// Return a new AnnData object with all backed arrays loaded into memory.
@@ -625,16 +642,25 @@ trait AnnDataTrait: Send + Sync + Downcast {
625642
backend: Option<&str>,
626643
) -> Result<Option<Bound<'py, PyAny>>>;
627644

628-
fn split_by(
629-
&self,
630-
key: Bound<'_, PyAny>,
631-
out_dir: PathBuf,
632-
) -> Result<HashMap<String, AnnData>>;
645+
fn split_by(&self, key: Bound<'_, PyAny>, out_dir: PathBuf)
646+
-> Result<HashMap<String, AnnData>>;
633647

634648
fn chunked_x(&self, chunk_size: usize) -> PyChunkedArray;
635649

636-
fn write(&self, filename: PathBuf, backend: &str, chunk_size: Option<usize>) -> Result<()>;
637-
fn copy(&self, filename: PathBuf, backend: &str, chunk_size: Option<usize>) -> Result<AnnData>;
650+
fn write(
651+
&self,
652+
filename: PathBuf,
653+
backend: &str,
654+
partial: Option<HashSet<String>>,
655+
chunk_size: Option<usize>,
656+
) -> Result<()>;
657+
fn copy(
658+
&self,
659+
filename: PathBuf,
660+
backend: &str,
661+
partial: Option<HashSet<String>>,
662+
chunk_size: Option<usize>,
663+
) -> Result<AnnData>;
638664
fn to_memory<'py>(&self, py: Python<'py>) -> Result<PyAnnData<'py>>;
639665

640666
fn filename(&self) -> PathBuf;
@@ -1018,21 +1044,26 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
10181044
}
10191045

10201046
fn split_by(
1021-
&self,
1022-
key: Bound<'_, PyAny>,
1023-
out_dir: PathBuf,
1024-
) -> Result<HashMap<String, AnnData>> {
1047+
&self,
1048+
key: Bound<'_, PyAny>,
1049+
out_dir: PathBuf,
1050+
) -> Result<HashMap<String, AnnData>> {
10251051
let inner = self.adata.inner();
10261052

10271053
let keys = if let Ok(key) = key.extract::<String>() {
10281054
let obs = inner.read_obs()?;
1029-
obs.column(&key)?.str()?.into_iter().map(|x| x.unwrap().to_string()).collect()
1055+
obs.column(&key)?
1056+
.str()?
1057+
.into_iter()
1058+
.map(|x| x.unwrap().to_string())
1059+
.collect()
10301060
} else {
10311061
key.extract::<Vec<String>>()?
10321062
};
1033-
1063+
10341064
let split_data = inner.split_by(&keys, &out_dir)?;
1035-
split_data.into_iter()
1065+
split_data
1066+
.into_iter()
10361067
.map(|(k, v)| Ok((k, AnnData::from(v))))
10371068
.collect()
10381069
}
@@ -1041,15 +1072,30 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
10411072
self.adata.inner().get_x().chunked(chunk_size).into()
10421073
}
10431074

1044-
fn write(&self, filename: PathBuf, backend: &str, chunk_size: Option<usize>) -> Result<()> {
1075+
fn write(
1076+
&self,
1077+
filename: PathBuf,
1078+
backend: &str,
1079+
partial: Option<HashSet<String>>,
1080+
chunk_size: Option<usize>,
1081+
) -> Result<()> {
10451082
match backend {
1046-
H5::NAME => self.adata.inner().write::<H5, _>(filename, chunk_size),
1083+
H5::NAME => self
1084+
.adata
1085+
.inner()
1086+
.write::<H5, _>(filename, partial, chunk_size),
10471087
x => bail!("Unsupported backend: {}", x),
10481088
}
10491089
}
10501090

1051-
fn copy(&self, filename: PathBuf, backend: &str, chunk_size: Option<usize>) -> Result<AnnData> {
1052-
AnnDataTrait::write(self, filename.clone(), backend, chunk_size)?;
1091+
fn copy(
1092+
&self,
1093+
filename: PathBuf,
1094+
backend: &str,
1095+
partial: Option<HashSet<String>>,
1096+
chunk_size: Option<usize>,
1097+
) -> Result<AnnData> {
1098+
AnnDataTrait::write(self, filename.clone(), backend, partial, chunk_size)?;
10531099
AnnData::new_from(filename, "r+", backend)
10541100
}
10551101

0 commit comments

Comments
 (0)