Skip to content

(fix): cache partial decoder #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 10, 2025
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ opendal = { version = "0.51.0", features = ["services-http"] }
tokio = { version = "1.41.1", features = ["rt-multi-thread"] }
zarrs_opendal = "0.5.0"
zarrs_metadata = "0.3.3" # require recent zarr-python compatibility fixes (remove with zarrs 0.20)
itertools = "0.9.0"

[profile.release]
lto = true
57 changes: 46 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#![allow(clippy::module_name_repetitions)]

use std::borrow::Cow;
use std::collections::HashMap;
use std::ptr::NonNull;
use std::sync::Arc;

use chunk_item::WithSubset;
use itertools::Itertools;
use numpy::npyffi::PyArrayObject;
use numpy::{PyArrayDescrMethods, PyUntypedArray, PyUntypedArrayMethods};
use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
Expand All @@ -14,12 +17,16 @@ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon_iter_concurrent_limit::iter_concurrent_limit;
use unsafe_cell_slice::UnsafeCellSlice;
use zarrs::array::codec::{ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder};
use utils::is_whole_chunk;
use zarrs::array::codec::{
ArrayPartialDecoderTraits, ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder,
};
use zarrs::array::{
copy_fill_value_into, update_array_bytes, ArrayBytes, ArraySize, CodecChain, FillValue,
};
use zarrs::array_subset::ArraySubset;
use zarrs::metadata::v3::MetadataV3;
use zarrs::storage::StoreKey;

mod chunk_item;
mod concurrency;
Expand Down Expand Up @@ -265,15 +272,41 @@ impl CodecPipelineImpl {
return Ok(());
};

// Assemble partial decoders ahead of time and in parallel
let partial_chunk_descriptions = chunk_descriptions
.iter()
.filter(|item| !(is_whole_chunk(item)))
.unique_by(|item| item.key())
.collect::<Vec<_>>();
let mut partial_decoder_cache: HashMap<StoreKey, Arc<dyn ArrayPartialDecoderTraits>> =
HashMap::new().into();
if partial_chunk_descriptions.len() > 0 {
let key_decoder_pairs = partial_chunk_descriptions
.into_par_iter()
.map(|item| {
let input_handle = self.stores.decoder(item)?;
let partial_decoder = self
.codec_chain
.clone()
.partial_decoder(
Arc::new(input_handle),
item.representation(),
&codec_options,
)
.map_py_err::<PyValueError>()?;
Ok((item.key().clone(), partial_decoder))
})
.collect::<PyResult<Vec<_>>>()?;
partial_decoder_cache.extend(key_decoder_pairs);
}

py.allow_threads(move || {
// FIXME: the `decode_into` methods only support fixed length data types.
// For variable length data types, need a codepath with non `_into` methods.
// Collect all the subsets and copy into value on the Python side?
let update_chunk_subset = |item: chunk_item::WithSubset| {
// See zarrs::array::Array::retrieve_chunk_subset_into
if item.chunk_subset.start().iter().all(|&o| o == 0)
&& item.chunk_subset.shape() == item.representation().shape_u64()
{
if is_whole_chunk(&item) {
// See zarrs::array::Array::retrieve_chunk_into
if let Some(chunk_encoded) = self.stores.get(&item)? {
// Decode the encoded data into the output buffer
Expand Down Expand Up @@ -308,18 +341,20 @@ impl CodecPipelineImpl {
}
}
} else {
let input_handle = Arc::new(self.stores.decoder(&item)?);
let partial_decoder = self
.codec_chain
.clone()
.partial_decoder(input_handle, item.representation(), &codec_options)
.map_py_err::<PyValueError>()?;
let key = item.key();
let partial_decoder: PyResult<&Arc<dyn ArrayPartialDecoderTraits>> =
match partial_decoder_cache.get(key) {
Some(e) => Ok(e),
None => Err(PyRuntimeError::new_err(format!(
"Partial decoder not found for key: {key}"
))),
};
unsafe {
// SAFETY:
// - output is an array with output_shape elements of the item.representation data type,
// - item.subset is within the bounds of output_shape.
// - item.chunk_subset has the same number of elements as item.subset.
partial_decoder.partial_decode_into(
partial_decoder?.partial_decode_into(
&item.chunk_subset,
&output,
&output_shape,
Expand Down
7 changes: 7 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::fmt::Display;
use numpy::{PyUntypedArray, PyUntypedArrayMethods};
use pyo3::{Bound, PyErr, PyResult, PyTypeInfo};

use crate::{ChunksItem, WithSubset};

pub(crate) trait PyErrExt<T> {
fn map_py_err<PE: PyTypeInfo>(self) -> PyResult<T>;
}
Expand All @@ -29,3 +31,8 @@ impl PyUntypedArrayExt for Bound<'_, PyUntypedArray> {
})
}
}

/// Returns `true` when `item.chunk_subset` spans the entire chunk, i.e. the
/// subset starts at the origin and its shape equals the chunk representation's
/// shape. Used to pick the fast whole-chunk decode path over a partial decoder.
// `pub(crate)` for consistency with the rest of this module (`PyErrExt` is
// `pub(crate)`); the only caller is crate-internal (src/lib.rs).
pub(crate) fn is_whole_chunk(item: &WithSubset) -> bool {
    item.chunk_subset.start().iter().all(|&o| o == 0)
        && item.chunk_subset.shape() == item.representation().shape_u64()
}
Loading