Skip to content

Commit 0bfa3bc

Browse files
fix: bring back zarr (#21)
Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent cf0b312 commit 0bfa3bc

5 files changed

Lines changed: 113 additions & 52 deletions

File tree

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
members = [
33
"anndata",
44
"anndata-hdf5",
5+
"anndata-zarr",
56
"pyanndata",
67
"anndata-test-utils",
78
"python",
@@ -11,4 +12,5 @@ resolver = "2"
1112
[workspace.dependencies]
1213
anndata = { path = "anndata" }
1314
anndata-hdf5 = { path = "anndata-hdf5" }
15+
anndata-zarr = { path = "anndata-zarr" }
1416
pyanndata = { path = "pyanndata" }

anndata-test-utils/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ itertools = "0.14"
1818

1919
[dev-dependencies]
2020
anndata-hdf5 = { workspace = true }
21+
anndata-zarr = { workspace = true }
2122
tempfile = "3.2"
2223
proptest = "1"
2324
bincode = { version = "2", features = ["serde"] }

anndata-test-utils/tests/tests.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
use anndata_test_utils as utils;
22
use anndata_test_utils::with_tmp_dir;
33
use anndata_hdf5::H5;
4+
use anndata_zarr::Zarr;
45
use anndata::{AnnData, Backend};
56

67
#[test]
78
fn test_basic() {
89
utils::test_basic::<H5>();
10+
utils::test_basic::<Zarr>();
911
}
1012

1113
#[test]
@@ -15,13 +17,16 @@ fn test_complex_dataframe() {
1517
let file = dir.join("test.h5");
1618
let adata = AnnData::<H5>::open(H5::open(&input).unwrap()).unwrap();
1719
adata.write::<H5, _>(file, None, None).unwrap();
20+
21+
let file = dir.join("test.zarr");
22+
adata.write::<Zarr, _>(file, None, None).unwrap();
1823
})
1924
}
2025

2126
#[test]
2227
fn test_save() {
2328
utils::test_save::<H5>();
24-
//utils::test_save::<Zarr>();
29+
utils::test_save::<Zarr>();
2530
}
2631

2732
#[test]
@@ -31,11 +36,10 @@ fn test_speacial_cases() {
3136
let adata_gen = || AnnData::<H5>::new(&file).unwrap();
3237
utils::test_speacial_cases(|| adata_gen());
3338

34-
/*
3539
let file = dir.join("test.zarr");
3640
let adata_gen = || AnnData::<Zarr>::new(&file).unwrap();
3741
utils::test_speacial_cases(|| adata_gen());
38-
*/
42+
3943
})
4044
}
4145

@@ -45,6 +49,10 @@ fn test_noncanonical() {
4549
let file = dir.join("test.h5");
4650
let adata_gen = || AnnData::<H5>::new(&file).unwrap();
4751
utils::test_noncanonical(|| adata_gen());
52+
53+
let file = dir.join("test.zarr");
54+
let adata_gen = || AnnData::<Zarr>::new(&file).unwrap();
55+
utils::test_noncanonical(|| adata_gen());
4856
})
4957
}
5058

@@ -80,5 +88,9 @@ fn test_iterator() {
8088
let file = dir.join("test.h5");
8189
let adata_gen = || AnnData::<H5>::new(&file).unwrap();
8290
utils::test_iterator(|| adata_gen());
91+
92+
let file = dir.join("test.zarr");
93+
let adata_gen = || AnnData::<Zarr>::new(&file).unwrap();
94+
utils::test_iterator(|| adata_gen());
8395
})
8496
}

anndata-zarr/Cargo.toml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "anndata-zarr"
3-
version = "0.1.0"
3+
version = "0.2.0"
44
edition = "2021"
55
rust-version = "1.75"
66
authors = ["Kai Zhang <kai@kzhang.org>"]
@@ -14,6 +14,12 @@ homepage = "https://github.com/kaizhang/anndata-rs"
1414
anndata = { workspace = true }
1515
serde_json = "1.0"
1616
anyhow = "1.0"
17-
ndarray = { version = "0.16", features = ["serde"] }
18-
zarrs = "0.21"
19-
smallvec = "1.15"
17+
ndarray = { version = "0.17", features = ["serde"] }
18+
zarrs = "0.23"
19+
smallvec = "1.15"
20+
21+
[dev-dependencies]
22+
tempfile = "3.2"
23+
proptest = "1"
24+
rand = "0.9"
25+
ndarray-rand = "0.16"

anndata-zarr/src/lib.rs

Lines changed: 85 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,24 @@ use anyhow::{bail, Context, Result};
77
use ndarray::{Array, ArrayD, ArrayView, CowArray, Dimension, IxDyn, SliceInfoElem};
88
use std::{
99
borrow::Cow,
10+
num::NonZeroU64,
1011
ops::{Deref, Index},
1112
path::{Path, PathBuf},
1213
};
1314
use std::{sync::Arc, vec};
14-
use zarrs::array::codec::bytes_to_bytes::zstd::ZstdCodec;
15+
use zarrs::array::{
16+
ZARR_NAN_F32, ZARR_NAN_F64, codec::bytes_to_bytes::zstd::ZstdCodec, data_type::{
17+
BoolDataType, Float32DataType, Float64DataType, Int8DataType, Int16DataType, Int32DataType, Int64DataType, StringDataType, UInt8DataType, UInt16DataType, UInt32DataType, UInt64DataType
18+
}
19+
};
1520
use zarrs::filesystem::FilesystemStore;
1621
use zarrs::group::Group;
1722
use zarrs::{array::ElementOwned, storage::ReadableWritableListableStorageTraits};
1823
use zarrs::{
19-
array::{codec::ShardingCodecBuilder, data_type::DataType, ArrayShardedReadableExt, Element},
20-
array_subset::ArraySubset,
24+
array::{
25+
codec::ShardingCodecBuilder, data_type, ArrayShardedReadableExt, ArraySubset, Element,
26+
FillValue,
27+
},
2128
storage::StorePrefix,
2229
};
2330

@@ -344,20 +351,32 @@ impl AttributeOp<Zarr> for ZarrDataset {
344351

345352
impl DatasetOp<Zarr> for ZarrDataset {
346353
fn dtype(&self) -> Result<ScalarType> {
347-
match self.dataset.data_type() {
348-
DataType::UInt8 => Ok(ScalarType::U8),
349-
DataType::UInt16 => Ok(ScalarType::U16),
350-
DataType::UInt32 => Ok(ScalarType::U32),
351-
DataType::UInt64 => Ok(ScalarType::U64),
352-
DataType::Int8 => Ok(ScalarType::I8),
353-
DataType::Int16 => Ok(ScalarType::I16),
354-
DataType::Int32 => Ok(ScalarType::I32),
355-
DataType::Int64 => Ok(ScalarType::I64),
356-
DataType::Float32 => Ok(ScalarType::F32),
357-
DataType::Float64 => Ok(ScalarType::F64),
358-
DataType::Bool => Ok(ScalarType::Bool),
359-
DataType::String => Ok(ScalarType::String),
360-
ty => bail!("Unsupported type: {:?}", ty),
354+
if self.dataset.data_type().is::<UInt8DataType>() {
355+
Ok(ScalarType::U8)
356+
} else if self.dataset.data_type().is::<UInt16DataType>() {
357+
Ok(ScalarType::U16)
358+
} else if self.dataset.data_type().is::<UInt32DataType>() {
359+
Ok(ScalarType::U32)
360+
} else if self.dataset.data_type().is::<UInt64DataType>() {
361+
Ok(ScalarType::U64)
362+
} else if self.dataset.data_type().is::<Int8DataType>() {
363+
Ok(ScalarType::I8)
364+
} else if self.dataset.data_type().is::<Int16DataType>() {
365+
Ok(ScalarType::I16)
366+
} else if self.dataset.data_type().is::<Int32DataType>() {
367+
Ok(ScalarType::I32)
368+
} else if self.dataset.data_type().is::<Int64DataType>() {
369+
Ok(ScalarType::I64)
370+
} else if self.dataset.data_type().is::<Float32DataType>() {
371+
Ok(ScalarType::F32)
372+
} else if self.dataset.data_type().is::<Float64DataType>() {
373+
Ok(ScalarType::F64)
374+
} else if self.dataset.data_type().is::<BoolDataType>() {
375+
Ok(ScalarType::Bool)
376+
} else if self.dataset.data_type().is::<StringDataType>() {
377+
Ok(ScalarType::String)
378+
} else {
379+
bail!("Unsupported type: {:?}", self.dataset.data_type())
361380
}
362381
}
363382

@@ -371,8 +390,11 @@ impl DatasetOp<Zarr> for ZarrDataset {
371390

372391
fn reshape(&mut self, shape: &Shape) -> Result<()> {
373392
self.dataset
374-
.set_shape(shape.as_ref().iter().map(|x| *x as u64).collect());
393+
.set_shape(shape.as_ref().iter().map(|x| *x as u64).collect())?;
375394
self.dataset.store_metadata()?;
395+
// The intenion of the caching API is no mutation after creation:
396+
// https://ossci.zulipchat.com/#narrow/channel/423692-Zarr/topic/zarrs.20.60ArrayShardedReadableExtCache.60.20.2B.20.60set_shape.60/with/595519775
397+
self.cache.clear();
376398
Ok(())
377399
}
378400

@@ -392,21 +414,21 @@ impl DatasetOp<Zarr> for ZarrDataset {
392414
if let Some(subset) = to_array_subset(sel) {
393415
let arr = dataset
394416
.dataset
395-
.retrieve_array_subset_ndarray_sharded_opt(
417+
.retrieve_array_subset_sharded_opt::<ndarray::ArrayD<T>>(
396418
&dataset.cache,
397419
&subset,
398-
&zarrs::array::codec::CodecOptions::default(),
420+
&zarrs::array::CodecOptions::default(),
399421
)?
400422
.into_dimensionality::<D>()?;
401423
Ok(arr)
402424
} else {
403425
// Read the entire array and then select the slice.
404426
let arr = dataset
405427
.dataset
406-
.retrieve_array_subset_ndarray_sharded_opt(
428+
.retrieve_array_subset_sharded_opt::<ndarray::ArrayD<T>>(
407429
&dataset.cache,
408430
&dataset.dataset.subset_all(),
409-
&zarrs::array::codec::CodecOptions::default(),
431+
&zarrs::array::CodecOptions::default(),
410432
)?
411433
.into_dimensionality::<D>()?;
412434
Ok(select(arr.view(), selection))
@@ -461,9 +483,10 @@ impl DatasetOp<Zarr> for ZarrDataset {
461483
})
462484
.collect();
463485
if starts.len() == selection.ndim() {
486+
container.cache.clear();
464487
container
465488
.dataset
466-
.store_array_subset_ndarray(starts.as_slice(), arr.into_owned())?;
489+
.store_array_subset(&ArraySubset::new_with_start_shape(starts, arr.shape().iter().map(|x| *x as u64).collect())?, arr.to_owned())?;
467490
} else {
468491
panic!("Not implemented");
469492
}
@@ -567,23 +590,27 @@ fn new_empty_dataset_helper<T: BackendData, S: ?Sized>(
567590
config: WriteConfig,
568591
) -> Result<zarrs::array::Array<S>> {
569592
let (datatype, fill) = match T::DTYPE {
570-
ScalarType::U8 => (DataType::UInt8, 0u8.into()),
571-
ScalarType::U16 => (DataType::UInt16, 0u16.into()),
572-
ScalarType::U32 => (DataType::UInt32, 0u32.into()),
573-
ScalarType::U64 => (DataType::UInt64, 0u64.into()),
574-
ScalarType::I8 => (DataType::Int8, 0i8.into()),
575-
ScalarType::I16 => (DataType::Int16, 0i16.into()),
576-
ScalarType::I32 => (DataType::Int32, 0i32.into()),
577-
ScalarType::I64 => (DataType::Int64, 0i64.into()),
578-
ScalarType::F32 => (DataType::Float32, zarrs::array::ZARR_NAN_F32.into()),
579-
ScalarType::F64 => (DataType::Float64, zarrs::array::ZARR_NAN_F64.into()),
580-
ScalarType::Bool => (DataType::Bool, false.into()),
581-
ScalarType::String => (DataType::String, "".into()),
593+
ScalarType::U8 => (data_type::uint8(), FillValue::from(0u8)),
594+
ScalarType::U16 => (data_type::uint16(), FillValue::from(0u16)),
595+
ScalarType::U32 => (data_type::uint32(), FillValue::from(0u32)),
596+
ScalarType::U64 => (data_type::uint64(), FillValue::from(0u64)),
597+
ScalarType::I8 => (data_type::int8(), FillValue::from(0i8)),
598+
ScalarType::I16 => (data_type::int16(), FillValue::from(0i16)),
599+
ScalarType::I32 => (data_type::int32(), FillValue::from(0i32)),
600+
ScalarType::I64 => (data_type::int64(), FillValue::from(0i64)),
601+
ScalarType::F32 => (data_type::float32(), FillValue::from(ZARR_NAN_F32)),
602+
ScalarType::F64 => (data_type::float64(), FillValue::from(ZARR_NAN_F64)),
603+
ScalarType::Bool => (data_type::bool(), FillValue::from(false)),
604+
ScalarType::String => (data_type::string(), FillValue::from("")),
582605
};
583606

584607
let shape = shape.as_ref();
585608
let chunk_size: Vec<u64> = match config.block_size {
586-
Some(s) => s.as_ref().into_iter().map(|x| (*x).max(1) as u64).collect(),
609+
Some(s) => s
610+
.as_ref()
611+
.into_iter()
612+
.map(|x| (*x).max(1) as u64)
613+
.collect::<Vec<_>>(),
587614
_ => {
588615
if shape.len() == 1 {
589616
vec![shape[0].min(16384).max(1) as u64]
@@ -594,29 +621,35 @@ fn new_empty_dataset_helper<T: BackendData, S: ?Sized>(
594621
};
595622

596623
let mut use_sharding = true;
597-
if matches!(datatype, DataType::String) {//|| shape.iter().sum::<usize>() == 0 {
624+
if datatype == data_type::string() {
625+
//|| shape.iter().sum::<usize>() == 0 {
598626
// Strings are not sharded, they are stored as a single chunk.
599627
use_sharding = false;
600628
}
601629

602630
let array = if use_sharding {
603631
let shard_shape = chunk_size.iter().map(|&x| x * 8).collect::<Vec<_>>();
604-
let mut sharding_codec_builder =
605-
ShardingCodecBuilder::new(chunk_size.try_into()?);
632+
let mut sharding_codec_builder = ShardingCodecBuilder::new(
633+
chunk_size
634+
.iter()
635+
.map(|e| NonZeroU64::try_from(*e))
636+
.collect::<Result<Vec<NonZeroU64>, _>>()?,
637+
&datatype,
638+
);
606639
sharding_codec_builder.bytes_to_bytes_codecs(vec![Arc::new(ZstdCodec::new(7, false))]);
607640
zarrs::array::ArrayBuilder::new(
608-
shape.iter().map(|x| *x as u64).collect(),
641+
shape.iter().map(|x| *x as u64).collect::<Vec<_>>(),
642+
shard_shape.as_slice(),
609643
datatype,
610-
shard_shape.try_into()?,
611644
fill,
612645
)
613646
.array_to_bytes_codec(sharding_codec_builder.build_arc())
614647
.build(store, path)?
615648
} else {
616649
zarrs::array::ArrayBuilder::new(
617-
shape.iter().map(|x| *x as u64).collect(),
650+
shape.iter().map(|x| *x as u64).collect::<Vec<_>>(),
651+
chunk_size.as_slice(),
618652
datatype,
619-
chunk_size.try_into()?,
620653
fill,
621654
)
622655
.bytes_to_bytes_codecs(vec![Arc::new(ZstdCodec::new(7, false))])
@@ -710,15 +743,22 @@ mod tests {
710743
let mut dataset =
711744
group.new_empty_dataset::<i32>("test", &[20, 50].as_slice().into(), config)?;
712745

713-
let arr = Array::random((10, 10), Uniform::new(0, 100));
746+
// Repeated writes force cache clearance
747+
let arr: ndarray::Array2<i32> = Array::random((10, 10), Uniform::new(0, 100).unwrap());
748+
dataset.write_array_slice(arr.view().into(), s![5..15, 10..20].as_ref())?;
749+
assert_eq!(
750+
arr,
751+
dataset.read_array_slice::<i32, _, _>(s![5..15, 10..20].as_ref())?
752+
);
753+
let arr: ndarray::Array2<i32> = Array::random((10, 10), Uniform::new(0, 100).unwrap());
714754
dataset.write_array_slice(arr.view().into(), s![5..15, 10..20].as_ref())?;
715755
assert_eq!(
716756
arr,
717757
dataset.read_array_slice::<i32, _, _>(s![5..15, 10..20].as_ref())?
718758
);
719759

720760
// Repeatitive writes
721-
let arr = Array::random((20, 50), Uniform::new(0, 100));
761+
let arr = Array::random((20, 50), Uniform::new(0, 100).unwrap());
722762
dataset.write_array_slice(arr.view().into(), s![.., ..].as_ref())?;
723763
dataset.write_array_slice(arr.view().into(), s![.., ..].as_ref())?;
724764

0 commit comments

Comments
 (0)