Skip to content

Commit 1a2232e

Browse files
committed
fix slicing for multi-dimensional arrays
1 parent 868d43d commit 1a2232e

3 files changed

Lines changed: 62 additions & 26 deletions

File tree

anndata-hdf5/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ homepage = "https://github.com/kaizhang/anndata-rs"
1313
[dependencies]
1414
anndata = { workspace = true }
1515
anyhow = "1.0"
16+
itertools = "0.14"
1617
hdf5 = { package = "hdf5-metno", version = "=0.10.1", features = ["blosc", "blosc-zstd"] }
1718
blosc-src = { version = "0.3.0", features = ["zstd"] }
1819
hdf5-sys = { package = "hdf5-metno-sys", version = "0.10", features = ["static", "zlib", "threadsafe"] }

anndata-hdf5/src/lib.rs

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use anndata::{
22
backend::*,
3-
data::{DynArray, DynCowArray, DynScalar, SelectInfoBounds, SelectInfoElem, SelectInfoElemBounds, Shape},
3+
data::{
4+
DynArray, DynCowArray, DynScalar, SelectInfoBounds, SelectInfoElem, SelectInfoElemBounds,
5+
Shape,
6+
},
47
};
58

69
use anyhow::{bail, Ok, Result};
@@ -10,9 +13,9 @@ use hdf5::{
1013
types::{FloatSize, TypeDescriptor, VarLenUnicode},
1114
File, Group, H5Type, Location, Selection,
1215
};
16+
use itertools::{EitherOrBoth, Itertools};
1317
use ndarray::{Array, ArrayD, ArrayView, CowArray, Dimension, IxDyn, SliceInfo, SliceInfoElem};
14-
use std::ops::Deref;
15-
use std::ops::Index;
18+
use std::ops::{Deref, Index};
1619
use std::path::{Path, PathBuf};
1720

1821
///////////////////////////////////////////////////////////////////////////////
@@ -129,7 +132,7 @@ fn new_dataset<T: BackendData>(
129132
Compression::Zst(lvl) => match dtype {
130133
ScalarType::String => builder.deflate(3),
131134
_ => builder.blosc_zstd(lvl, hdf5::filters::BloscShuffle::Byte),
132-
}
135+
},
133136
}
134137
} else {
135138
builder
@@ -306,8 +309,15 @@ impl DatasetOp<H5> for H5Dataset {
306309
let select: Vec<_> = info
307310
.as_ref()
308311
.into_iter()
309-
.zip(shape)
310-
.map(|(x, n)| SelectInfoElemBounds::new(x.as_ref(), *n))
312+
.zip_longest(shape)
313+
.map(|ty| match ty {
314+
EitherOrBoth::Both(x, n) => SelectInfoElemBounds::new(x.as_ref(), *n),
315+
EitherOrBoth::Right(n) => SelectInfoElemBounds::new(
316+
&SelectInfoElem::Slice(anndata::data::slice::SLICE_FULL),
317+
*n,
318+
),
319+
_ => panic!("inconsistent selection length"),
320+
})
311321
.collect();
312322
let new_shape = select.iter().map(|x| x.len()).collect::<Vec<_>>();
313323
ArrayD::from_shape_fn(new_shape, |idx| {
@@ -542,14 +552,25 @@ impl AttributeOp<H5> for H5Group {
542552
match value {
543553
Value::Null => Ok(()),
544554
Value::Bool(b) => write_scalar_attr(self, name, *b),
545-
Value::Number(n) => n.as_u64().map(|i| write_scalar_attr(self, name, i))
555+
Value::Number(n) => n
556+
.as_u64()
557+
.map(|i| write_scalar_attr(self, name, i))
546558
.or_else(|| n.as_i64().map(|i| write_scalar_attr(self, name, i)))
547559
.or_else(|| n.as_f64().map(|i| write_scalar_attr(self, name, i)))
548560
.expect("number cannot be converted to u64, i64 or f64"),
549561
Value::String(s) => write_scalar_attr(self, name, s.clone()),
550-
Value::Array(_) => json_to_ndarray(value, |x| x.as_i64())?.map(|x| write_array_attr(self, name, &x))
551-
.or_else(|| json_to_ndarray(value, |x| x.as_f64()).unwrap().map(|x| write_array_attr(self, name, &x)))
552-
.or_else(|| json_to_ndarray(value, |x| x.as_str().map(|s| s.to_string())).unwrap().map(|x| write_array_attr(self, name, &x)))
562+
Value::Array(_) => json_to_ndarray(value, |x| x.as_i64())?
563+
.map(|x| write_array_attr(self, name, &x))
564+
.or_else(|| {
565+
json_to_ndarray(value, |x| x.as_f64())
566+
.unwrap()
567+
.map(|x| write_array_attr(self, name, &x))
568+
})
569+
.or_else(|| {
570+
json_to_ndarray(value, |x| x.as_str().map(|s| s.to_string()))
571+
.unwrap()
572+
.map(|x| write_array_attr(self, name, &x))
573+
})
553574
.expect("array cannot be converted to i64, f64 or string"),
554575
Value::Object(_) => bail!("attributes of object type are not supported"),
555576
}
@@ -577,14 +598,25 @@ impl AttributeOp<H5> for H5Dataset {
577598
match value {
578599
Value::Null => Ok(()),
579600
Value::Bool(b) => write_scalar_attr(self, name, *b),
580-
Value::Number(n) => n.as_u64().map(|i| write_scalar_attr(self, name, i))
601+
Value::Number(n) => n
602+
.as_u64()
603+
.map(|i| write_scalar_attr(self, name, i))
581604
.or_else(|| n.as_i64().map(|i| write_scalar_attr(self, name, i)))
582605
.or_else(|| n.as_f64().map(|i| write_scalar_attr(self, name, i)))
583606
.expect("number cannot be converted to u64, i64 or f64"),
584607
Value::String(s) => write_scalar_attr(self, name, s.clone()),
585-
Value::Array(_) => json_to_ndarray(value, |x| x.as_i64())?.map(|x| write_array_attr(self, name, &x))
586-
.or_else(|| json_to_ndarray(value, |x| x.as_f64()).unwrap().map(|x| write_array_attr(self, name, &x)))
587-
.or_else(|| json_to_ndarray(value, |x| x.as_str().map(|s| s.to_string())).unwrap().map(|x| write_array_attr(self, name, &x)))
608+
Value::Array(_) => json_to_ndarray(value, |x| x.as_i64())?
609+
.map(|x| write_array_attr(self, name, &x))
610+
.or_else(|| {
611+
json_to_ndarray(value, |x| x.as_f64())
612+
.unwrap()
613+
.map(|x| write_array_attr(self, name, &x))
614+
})
615+
.or_else(|| {
616+
json_to_ndarray(value, |x| x.as_str().map(|s| s.to_string()))
617+
.unwrap()
618+
.map(|x| write_array_attr(self, name, &x))
619+
})
588620
.expect("array cannot be converted to i64, f64 or string"),
589621
Value::Object(_) => bail!("attributes of object type are not supported"),
590622
}
@@ -616,14 +648,15 @@ fn read_scalar_attr(loc: &Location, name: &str) -> Result<Value> {
616648
Ok(result)
617649
}
618650

619-
fn read_array_attr(
620-
loc: &Location,
621-
name: &str,
622-
) -> Result<Value> {
651+
fn read_array_attr(loc: &Location, name: &str) -> Result<Value> {
623652
let attr = loc.attr(name)?;
624653
let result = match attr.dtype()?.to_descriptor()? {
625-
TypeDescriptor::VarLenUnicode => ndarray_to_json(&attr.read::<VarLenUnicode, IxDyn>()?.mapv(|x| x.to_string())),
626-
TypeDescriptor::VarLenAscii => ndarray_to_json(&attr.read::<VarLenUnicode, IxDyn>()?.mapv(|x| x.to_string())),
654+
TypeDescriptor::VarLenUnicode => {
655+
ndarray_to_json(&attr.read::<VarLenUnicode, IxDyn>()?.mapv(|x| x.to_string()))
656+
}
657+
TypeDescriptor::VarLenAscii => {
658+
ndarray_to_json(&attr.read::<VarLenUnicode, IxDyn>()?.mapv(|x| x.to_string()))
659+
}
627660
TypeDescriptor::Boolean => ndarray_to_json(&attr.read::<bool, IxDyn>()?),
628661
TypeDescriptor::Unsigned(_) => ndarray_to_json(&attr.read::<u64, IxDyn>()?),
629662
TypeDescriptor::Integer(_) => ndarray_to_json(&attr.read::<i64, IxDyn>()?),
@@ -654,7 +687,8 @@ where
654687
DynCowArray::F64(x) => loc.new_attr_builder().with_data(x.view()).create(name)?,
655688
DynCowArray::Bool(x) => loc.new_attr_builder().with_data(x.view()).create(name)?,
656689
DynCowArray::String(x) => {
657-
let data: Array<VarLenUnicode, Dim> = x.map(|x| x.parse().unwrap()).into_dimensionality()?;
690+
let data: Array<VarLenUnicode, Dim> =
691+
x.map(|x| x.parse().unwrap()).into_dimensionality()?;
658692
loc.new_attr_builder().with_data(data.view()).create(name)?
659693
}
660694
};
@@ -766,9 +800,10 @@ fn ndarray_to_json<T: Into<Value> + Clone>(array: &ArrayD<T>) -> Value {
766800
vec.into()
767801
} else {
768802
// Recursive case: split along the first axis and apply recursively
769-
let nested_vec = array.outer_iter().map(|sub_array| {
770-
recursive_convert(&sub_array.to_owned().into_dyn())
771-
}).collect::<Vec<Value>>();
803+
let nested_vec = array
804+
.outer_iter()
805+
.map(|sub_array| recursive_convert(&sub_array.to_owned().into_dyn()))
806+
.collect::<Vec<Value>>();
772807
Value::Array(nested_vec)
773808
}
774809
}

anndata/src/backend.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub struct WriteConfig {
2424
impl Default for WriteConfig {
2525
fn default() -> Self {
2626
Self {
27-
compression: Some(Compression::Zst(3)),
27+
compression: Some(Compression::Zst(5)),
2828
block_size: None,
2929
}
3030
}
@@ -338,7 +338,7 @@ impl<B: Backend> DataContainer<B> {
338338
"dataframe" => DataType::DataFrame,
339339
"mapping" | "dict" => DataType::Mapping,
340340
"nullable-integer" | "nullable-boolean" => DataType::NullableArray,
341-
ty => bail!("Unsupported type '{}'", ty),
341+
ty => bail!("the anndata file contains an unsupported encoding type: '{}'", ty),
342342
};
343343
Ok(ty)
344344
}

0 commit comments

Comments
 (0)