Skip to content

Commit ebcdb99

Browse files
committed
Implement default write configuration management and compression options in Python bindings
1 parent 541b240 commit ebcdb99

10 files changed

Lines changed: 234 additions & 28 deletions

File tree

anndata/src/backend.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use anyhow::{bail, Result};
66
use core::fmt::{Debug, Formatter};
77
use ndarray::{arr0, Array, CowArray, Dimension, Ix0, IxDyn};
88
use std::path::{Path, PathBuf};
9+
use std::cell::RefCell;
910
pub use serde_json::Value;
1011
use serde::Deserialize;
1112

12-
#[derive(Debug, Copy, Clone)]
13+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
1314
pub enum Compression {
1415
Gzip(u8),
1516
Zst(u8),
@@ -30,6 +31,20 @@ impl Default for WriteConfig {
3031
}
3132
}
3233

34+
thread_local! {
35+
static DEFAULT_WRITE_CONFIG: RefCell<WriteConfig> = RefCell::new(WriteConfig::default());
36+
}
37+
38+
/// Set the default write configuration for subsequent write operations on the current thread (the value is stored in a thread-local, so other threads keep their own default).
39+
pub fn set_default_write_config(config: WriteConfig) {
40+
DEFAULT_WRITE_CONFIG.with(|c| *c.borrow_mut() = config);
41+
}
42+
43+
/// Get the calling thread's current default write configuration.
44+
pub fn get_default_write_config() -> WriteConfig {
45+
DEFAULT_WRITE_CONFIG.with(|c| c.borrow().clone())
46+
}
47+
3348
pub trait Backend: 'static {
3449
/// The name of the backend.
3550
const NAME: &'static str;
@@ -119,7 +134,7 @@ pub trait GroupOp<B: Backend + ?Sized> {
119134
}
120135

121136
fn new_scalar_dataset<D: BackendData>(&self, name: &str, data: &D) -> Result<B::Dataset> {
122-
self.new_array_dataset(name, arr0(data.clone()).into(), WriteConfig::default())
137+
self.new_array_dataset(name, arr0(data.clone()).into(), get_default_write_config())
123138
}
124139
}
125140

anndata/src/data/array/chunks.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use anyhow::{Context, Result, bail};
77
use nalgebra_sparse::na::Scalar;
88
use nalgebra_sparse::{CscMatrix, CsrMatrix};
99
use ndarray::{Array, Array1, ArrayD, ArrayView1, RemoveAxis};
10+
use crate::backend::get_default_write_config;
1011

1112
pub enum MatrixBuilder<B: Backend> {
1213
CsrMatrix(CsrMatrixBuilder<B>),
@@ -125,7 +126,7 @@ impl<B: Backend> CsrMatrixBuilder<B> {
125126
self.data.finish()?;
126127
self.indptr.push(self.nnz);
127128
self.group
128-
.new_array_dataset("indptr", self.indptr.into(), Default::default())?;
129+
.new_array_dataset("indptr", self.indptr.into(), get_default_write_config())?;
129130
self.group.new_attr(
130131
"shape",
131132
[self.num_rows as u64, self.num_cols.unwrap_or(0) as u64].as_slice(),
@@ -470,7 +471,7 @@ impl<T: BackendData> ArrayChunk for CsrMatrix<T> {
470471
indices.finish()?;
471472
data.finish()?;
472473
indptr.push(nnz);
473-
group.new_array_dataset("indptr", indptr.into(), Default::default())?;
474+
group.new_array_dataset("indptr", indptr.into(), get_default_write_config())?;
474475
group.new_attr(
475476
"shape",
476477
[num_rows as u64, num_cols.unwrap_or(0) as u64].as_slice(),
@@ -600,7 +601,7 @@ impl<T: BackendData> ArrayChunk for CsrNonCanonical<T> {
600601
indices.finish()?;
601602
data.finish()?;
602603
indptr.push(nnz);
603-
group.new_array_dataset("indptr", indptr.into(), Default::default())?;
604+
group.new_array_dataset("indptr", indptr.into(), get_default_write_config())?;
604605
group.new_attr(
605606
"shape",
606607
[num_rows as u64, num_cols.unwrap_or(0) as u64].as_slice(),
@@ -734,7 +735,7 @@ impl<T: BackendData + Scalar> ArrayChunk for CscMatrix<T> {
734735
indices.finish()?;
735736
data.finish()?;
736737
indptr.push(nnz);
737-
group.create_array_data("indptr", &indptr, Default::default())?;
738+
group.create_array_data("indptr", &indptr, get_default_write_config())?;
738739
group.write_array_attr("shape", &[num_rows.unwrap_or(0), num_cols])?;
739740
Ok(DataContainer::Group(group))
740741
*/

anndata/src/data/array/dense.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use ndarray::{Array, Array1, ArrayD, ArrayView, Axis, Dimension, RemoveAxis, Sli
1616
use polars::series::Series;
1717
use std::collections::HashMap;
1818
use std::ops::Index;
19+
use crate::backend::get_default_write_config;
1920

2021
impl<'a, T: BackendData, D> Element for ArrayView<'a, T, D> {
2122
fn metadata(&self) -> MetaData {
@@ -37,7 +38,7 @@ impl<'a, T: BackendData, D: Dimension> Writable for ArrayView<'a, T, D> {
3738
location: &G,
3839
name: &str,
3940
) -> Result<DataContainer<B>> {
40-
let dataset = location.new_array_dataset(name, self.into(), Default::default())?;
41+
let dataset = location.new_array_dataset(name, self.into(), get_default_write_config())?;
4142
let mut container = DataContainer::<B>::Dataset(dataset);
4243
self.metadata().save(&mut container)?;
4344
Ok(container)

anndata/src/data/array/sparse/csc.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use nalgebra_sparse::pattern::SparsityPattern;
1515
use ndarray::Ix1;
1616

1717
use super::super::slice::SliceBounds;
18+
use crate::backend::get_default_write_config;
1819

1920
impl<T> HasShape for CscMatrix<T> {
2021
fn shape(&self) -> Shape {
@@ -219,7 +220,7 @@ impl<T: BackendData> Writable for CscMatrix<T> {
219220
let shape = self.shape();
220221

221222
self.metadata().save(&mut group)?;
222-
group.new_array_dataset("data", self.values().into(), Default::default())?;
223+
group.new_array_dataset("data", self.values().into(), get_default_write_config())?;
223224

224225
let num_rows = shape[0];
225226
// Use i32 or i64 as indices type in order to be compatible with scipy
@@ -230,15 +231,15 @@ impl<T: BackendData> Writable for CscMatrix<T> {
230231
.map(|x| (*x).try_into().ok())
231232
.collect();
232233
if let Some(indptr_i32) = try_convert_indptr {
233-
group.new_array_dataset("indptr", indptr_i32.into(), Default::default())?;
234+
group.new_array_dataset("indptr", indptr_i32.into(), get_default_write_config())?;
234235
group.new_array_dataset(
235236
"indices",
236237
self.row_indices()
237238
.iter()
238239
.map(|x| (*x) as i32)
239240
.collect::<Vec<_>>()
240241
.into(),
241-
Default::default(),
242+
get_default_write_config(),
242243
)?;
243244
} else {
244245
group.new_array_dataset(
@@ -248,7 +249,7 @@ impl<T: BackendData> Writable for CscMatrix<T> {
248249
.map(|x| TryInto::<i64>::try_into(*x).unwrap())
249250
.collect::<Vec<_>>()
250251
.into(),
251-
Default::default(),
252+
get_default_write_config(),
252253
)?;
253254
group.new_array_dataset(
254255
"indices",
@@ -257,7 +258,7 @@ impl<T: BackendData> Writable for CscMatrix<T> {
257258
.map(|x| (*x) as i64)
258259
.collect::<Vec<_>>()
259260
.into(),
260-
Default::default(),
261+
get_default_write_config(),
261262
)?;
262263
}
263264
} else if TryInto::<i64>::try_into(num_rows.saturating_sub(1)).is_ok() {
@@ -268,7 +269,7 @@ impl<T: BackendData> Writable for CscMatrix<T> {
268269
.map(|x| TryInto::<i64>::try_into(*x).unwrap())
269270
.collect::<Vec<_>>()
270271
.into(),
271-
Default::default(),
272+
get_default_write_config(),
272273
)?;
273274
group.new_array_dataset(
274275
"indices",
@@ -277,7 +278,7 @@ impl<T: BackendData> Writable for CscMatrix<T> {
277278
.map(|x| (*x) as i64)
278279
.collect::<Vec<_>>()
279280
.into(),
280-
Default::default(),
281+
get_default_write_config(),
281282
)?;
282283
} else {
283284
panic!(

anndata/src/data/array/sparse/csr.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use ndarray::{Array1, ArrayD, Ix1};
1515
use num::{NumCast, ToPrimitive};
1616

1717
use super::super::slice::SliceBounds;
18+
use crate::backend::get_default_write_config;
1819

1920
impl<T: BackendData> Element for CsrMatrix<T> {
2021
fn data_type(&self) -> DataType {
@@ -230,7 +231,7 @@ impl<T: BackendData> Writable for CsrMatrix<T> {
230231
let shape = self.shape();
231232

232233
self.metadata().save(&mut group)?;
233-
group.new_array_dataset("data", self.values().into(), Default::default())?;
234+
group.new_array_dataset("data", self.values().into(), get_default_write_config())?;
234235

235236
let num_cols = shape[1];
236237
// Use i32 or i64 as indices type in order to be compatible with scipy
@@ -241,15 +242,15 @@ impl<T: BackendData> Writable for CsrMatrix<T> {
241242
.map(|x| (*x).try_into().ok())
242243
.collect();
243244
if let Some(indptr_i32) = try_convert_indptr {
244-
group.new_array_dataset("indptr", indptr_i32.into(), Default::default())?;
245+
group.new_array_dataset("indptr", indptr_i32.into(), get_default_write_config())?;
245246
group.new_array_dataset(
246247
"indices",
247248
self.col_indices()
248249
.iter()
249250
.map(|x| (*x) as i32)
250251
.collect::<Vec<_>>()
251252
.into(),
252-
Default::default(),
253+
get_default_write_config(),
253254
)?;
254255
} else {
255256
group.new_array_dataset(
@@ -259,7 +260,7 @@ impl<T: BackendData> Writable for CsrMatrix<T> {
259260
.map(|x| TryInto::<i64>::try_into(*x).unwrap())
260261
.collect::<Vec<_>>()
261262
.into(),
262-
Default::default(),
263+
get_default_write_config(),
263264
)?;
264265
group.new_array_dataset(
265266
"indices",
@@ -268,7 +269,7 @@ impl<T: BackendData> Writable for CsrMatrix<T> {
268269
.map(|x| (*x) as i64)
269270
.collect::<Vec<_>>()
270271
.into(),
271-
Default::default(),
272+
get_default_write_config(),
272273
)?;
273274
}
274275
} else if TryInto::<i64>::try_into(num_cols.saturating_sub(1)).is_ok() {
@@ -279,7 +280,7 @@ impl<T: BackendData> Writable for CsrMatrix<T> {
279280
.map(|x| TryInto::<i64>::try_into(*x).unwrap())
280281
.collect::<Vec<_>>()
281282
.into(),
282-
Default::default(),
283+
get_default_write_config(),
283284
)?;
284285
group.new_array_dataset(
285286
"indices",
@@ -288,7 +289,7 @@ impl<T: BackendData> Writable for CsrMatrix<T> {
288289
.map(|x| (*x) as i64)
289290
.collect::<Vec<_>>()
290291
.into(),
291-
Default::default(),
292+
get_default_write_config(),
292293
)?;
293294
} else {
294295
panic!(

anndata/src/data/array/sparse/noncanonical.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use ndarray::Ix1;
1515

1616
use super::super::slice::SliceBounds;
1717
use super::DynCsrMatrix;
18+
use crate::backend::get_default_write_config;
1819

1920
#[derive(Debug, Clone, PartialEq)]
2021
pub enum DynCsrNonCanonical {
@@ -652,7 +653,7 @@ impl<T: BackendData> Writable for CsrNonCanonical<T> {
652653

653654
self.metadata().save(&mut group)?;
654655

655-
group.new_array_dataset("data", self.values().into(), Default::default())?;
656+
group.new_array_dataset("data", self.values().into(), get_default_write_config())?;
656657

657658
let num_cols = shape[1];
658659
// Use i32 or i64 as indices type in order to be compatible with scipy
@@ -663,15 +664,15 @@ impl<T: BackendData> Writable for CsrNonCanonical<T> {
663664
.map(|x| (*x).try_into().ok())
664665
.collect();
665666
if let Some(indptr_i32) = try_convert_indptr {
666-
group.new_array_dataset("indptr", indptr_i32.into(), Default::default())?;
667+
group.new_array_dataset("indptr", indptr_i32.into(), get_default_write_config())?;
667668
group.new_array_dataset(
668669
"indices",
669670
self.col_indices()
670671
.iter()
671672
.map(|x| (*x) as i32)
672673
.collect::<Vec<_>>()
673674
.into(),
674-
Default::default(),
675+
get_default_write_config(),
675676
)?;
676677
} else {
677678
group.new_array_dataset(
@@ -681,7 +682,7 @@ impl<T: BackendData> Writable for CsrNonCanonical<T> {
681682
.map(|x| TryInto::<i64>::try_into(*x).unwrap())
682683
.collect::<Vec<_>>()
683684
.into(),
684-
Default::default(),
685+
get_default_write_config(),
685686
)?;
686687
group.new_array_dataset(
687688
"indices",
@@ -690,7 +691,7 @@ impl<T: BackendData> Writable for CsrNonCanonical<T> {
690691
.map(|x| (*x) as i64)
691692
.collect::<Vec<_>>()
692693
.into(),
693-
Default::default(),
694+
get_default_write_config(),
694695
)?;
695696
}
696697
} else if TryInto::<i64>::try_into(num_cols.saturating_sub(1)).is_ok() {
@@ -701,7 +702,7 @@ impl<T: BackendData> Writable for CsrNonCanonical<T> {
701702
.map(|x| TryInto::<i64>::try_into(*x).unwrap())
702703
.collect::<Vec<_>>()
703704
.into(),
704-
Default::default(),
705+
get_default_write_config(),
705706
)?;
706707
group.new_array_dataset(
707708
"indices",
@@ -710,7 +711,7 @@ impl<T: BackendData> Writable for CsrNonCanonical<T> {
710711
.map(|x| (*x) as i64)
711712
.collect::<Vec<_>>()
712713
.into(),
713-
Default::default(),
714+
get_default_write_config(),
714715
)?;
715716
} else {
716717
panic!(

0 commit comments

Comments
 (0)