diff --git a/Cargo.toml b/Cargo.toml
index 26c459f..94c33bc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,6 +23,7 @@ opendal = { version = "0.51.0", features = ["services-http"] }
 tokio = { version = "1.41.1", features = ["rt-multi-thread"] }
 zarrs_opendal = "0.5.0"
 zarrs_metadata = "0.3.3" # require recent zarr-python compatibility fixes (remove with zarrs 0.20)
+itertools = "0.14.0"
 
 [profile.release]
 lto = true
diff --git a/src/lib.rs b/src/lib.rs
index cae087c..a38be9b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,9 +2,12 @@
 #![allow(clippy::module_name_repetitions)]
 
 use std::borrow::Cow;
+use std::collections::HashMap;
 use std::ptr::NonNull;
 use std::sync::Arc;
 
+use chunk_item::WithSubset;
+use itertools::Itertools;
 use numpy::npyffi::PyArrayObject;
 use numpy::{PyArrayDescrMethods, PyUntypedArray, PyUntypedArrayMethods};
 use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
@@ -14,12 +17,16 @@ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
 use rayon::iter::{IntoParallelIterator, ParallelIterator};
 use rayon_iter_concurrent_limit::iter_concurrent_limit;
 use unsafe_cell_slice::UnsafeCellSlice;
-use zarrs::array::codec::{ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder};
+use utils::is_whole_chunk;
+use zarrs::array::codec::{
+    ArrayPartialDecoderTraits, ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder,
+};
 use zarrs::array::{
     copy_fill_value_into, update_array_bytes, ArrayBytes, ArraySize, CodecChain, FillValue,
 };
 use zarrs::array_subset::ArraySubset;
 use zarrs::metadata::v3::MetadataV3;
+use zarrs::storage::StoreKey;
 
 mod chunk_item;
 mod concurrency;
@@ -265,15 +272,44 @@ impl CodecPipelineImpl {
             return Ok(());
         };
 
+        // Assemble partial decoders ahead of time and in parallel, one per
+        // unique store key, so the per-chunk closure below only does a map
+        // lookup instead of rebuilding the decoder chain for every subset.
+        let partial_chunk_descriptions = chunk_descriptions
+            .iter()
+            .filter(|item| !is_whole_chunk(item))
+            .unique_by(|item| item.key())
+            .collect::<Vec<_>>();
+        let mut partial_decoder_cache: HashMap<StoreKey, Arc<dyn ArrayPartialDecoderTraits>> =
+            HashMap::new();
+        if !partial_chunk_descriptions.is_empty() {
+            let key_decoder_pairs = iter_concurrent_limit!(
+                chunk_concurrent_limit,
+                partial_chunk_descriptions,
+                map,
+                |item| {
+                    let input_handle = self.stores.decoder(item)?;
+                    let partial_decoder = self
+                        .codec_chain
+                        .clone()
+                        .partial_decoder(
+                            Arc::new(input_handle),
+                            item.representation(),
+                            &codec_options,
+                        )
+                        .map_py_err::<PyValueError>()?;
+                    Ok((item.key().clone(), partial_decoder))
+                }
+            )
+            .collect::<PyResult<Vec<_>>>()?;
+            partial_decoder_cache.extend(key_decoder_pairs);
+        }
+
         py.allow_threads(move || {
             // FIXME: the `decode_into` methods only support fixed length data types.
             // For variable length data types, need a codepath with non `_into` methods.
             // Collect all the subsets and copy into value on the Python side?
             let update_chunk_subset = |item: chunk_item::WithSubset| {
                 // See zarrs::array::Array::retrieve_chunk_subset_into
-                if item.chunk_subset.start().iter().all(|&o| o == 0)
-                    && item.chunk_subset.shape() == item.representation().shape_u64()
-                {
+                if is_whole_chunk(&item) {
                     // See zarrs::array::Array::retrieve_chunk_into
                     if let Some(chunk_encoded) = self.stores.get(&item)? {
                         // Decode the encoded data into the output buffer
@@ -308,12 +344,10 @@ impl CodecPipelineImpl {
                         }
                     }
                 } else {
-                    let input_handle = Arc::new(self.stores.decoder(&item)?);
-                    let partial_decoder = self
-                        .codec_chain
-                        .clone()
-                        .partial_decoder(input_handle, item.representation(), &codec_options)
-                        .map_py_err::<PyValueError>()?;
+                    let key = item.key();
+                    let partial_decoder = partial_decoder_cache.get(key).ok_or_else(|| {
+                        PyRuntimeError::new_err(format!("Partial decoder not found for key: {key}"))
+                    })?;
                     unsafe {
                         // SAFETY:
                         // - output is an array with output_shape elements of the item.representation data type,
diff --git a/src/utils.rs b/src/utils.rs
index 5855b54..b33b4b1 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -3,6 +3,8 @@ use std::fmt::Display;
 use numpy::{PyUntypedArray, PyUntypedArrayMethods};
 use pyo3::{Bound, PyErr, PyResult, PyTypeInfo};
 
+use crate::{ChunksItem, WithSubset};
+
 pub(crate) trait PyErrExt<T> {
     fn map_py_err<PE: PyTypeInfo>(self) -> PyResult<T>;
 }
@@ -29,3 +31,10 @@ impl PyUntypedArrayExt for Bound<'_, PyUntypedArray> {
         })
     }
 }
+
+/// Returns `true` if `item.chunk_subset` spans the whole chunk, i.e. it
+/// starts at the origin and its shape equals the chunk representation shape.
+pub fn is_whole_chunk(item: &WithSubset) -> bool {
+    item.chunk_subset.start().iter().all(|&o| o == 0)
+        && item.chunk_subset.shape() == item.representation().shape_u64()
+}