Skip to content

Commit dabb30c

Browse files
ilan-gold and LDeakin authored
(fix): cache partial decoder (#93)
* (fix): cache partial decoder * (refactor): local caching * (refactor): only use one hashmap * (refactor): only do unique on filtered object * Update src/lib.rs Co-authored-by: Lachlan Deakin <[email protected]> * Update src/lib.rs Co-authored-by: Lachlan Deakin <[email protected]> * Update src/lib.rs Co-authored-by: Lachlan Deakin <[email protected]> --------- Co-authored-by: Lachlan Deakin <[email protected]>
1 parent 23bef57 commit dabb30c

File tree

3 files changed

+52
-10
lines changed

3 files changed

+52
-10
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ opendal = { version = "0.51.0", features = ["services-http"] }
2323
tokio = { version = "1.41.1", features = ["rt-multi-thread"] }
2424
zarrs_opendal = "0.5.0"
2525
zarrs_metadata = "0.3.3" # require recent zarr-python compatibility fixes (remove with zarrs 0.20)
26+
itertools = "0.9.0"
2627

2728
[profile.release]
2829
lto = true

src/lib.rs

+44-10
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
#![allow(clippy::module_name_repetitions)]
33

44
use std::borrow::Cow;
5+
use std::collections::HashMap;
56
use std::ptr::NonNull;
67
use std::sync::Arc;
78

9+
use chunk_item::WithSubset;
10+
use itertools::Itertools;
811
use numpy::npyffi::PyArrayObject;
912
use numpy::{PyArrayDescrMethods, PyUntypedArray, PyUntypedArrayMethods};
1013
use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
@@ -14,12 +17,16 @@ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
1417
use rayon::iter::{IntoParallelIterator, ParallelIterator};
1518
use rayon_iter_concurrent_limit::iter_concurrent_limit;
1619
use unsafe_cell_slice::UnsafeCellSlice;
17-
use zarrs::array::codec::{ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder};
20+
use utils::is_whole_chunk;
21+
use zarrs::array::codec::{
22+
ArrayPartialDecoderTraits, ArrayToBytesCodecTraits, CodecOptions, CodecOptionsBuilder,
23+
};
1824
use zarrs::array::{
1925
copy_fill_value_into, update_array_bytes, ArrayBytes, ArraySize, CodecChain, FillValue,
2026
};
2127
use zarrs::array_subset::ArraySubset;
2228
use zarrs::metadata::v3::MetadataV3;
29+
use zarrs::storage::StoreKey;
2330

2431
mod chunk_item;
2532
mod concurrency;
@@ -265,15 +272,44 @@ impl CodecPipelineImpl {
265272
return Ok(());
266273
};
267274

275+
// Assemble partial decoders ahead of time and in parallel
276+
let partial_chunk_descriptions = chunk_descriptions
277+
.iter()
278+
.filter(|item| !(is_whole_chunk(item)))
279+
.unique_by(|item| item.key())
280+
.collect::<Vec<_>>();
281+
let mut partial_decoder_cache: HashMap<StoreKey, Arc<dyn ArrayPartialDecoderTraits>> =
282+
HashMap::new().into();
283+
if partial_chunk_descriptions.len() > 0 {
284+
let key_decoder_pairs = iter_concurrent_limit!(
285+
chunk_concurrent_limit,
286+
partial_chunk_descriptions,
287+
map,
288+
|item| {
289+
let input_handle = self.stores.decoder(item)?;
290+
let partial_decoder = self
291+
.codec_chain
292+
.clone()
293+
.partial_decoder(
294+
Arc::new(input_handle),
295+
item.representation(),
296+
&codec_options,
297+
)
298+
.map_py_err::<PyValueError>()?;
299+
Ok((item.key().clone(), partial_decoder))
300+
}
301+
)
302+
.collect::<PyResult<Vec<_>>>()?;
303+
partial_decoder_cache.extend(key_decoder_pairs);
304+
}
305+
268306
py.allow_threads(move || {
269307
// FIXME: the `decode_into` methods only support fixed length data types.
270308
// For variable length data types, need a codepath with non `_into` methods.
271309
// Collect all the subsets and copy into value on the Python side?
272310
let update_chunk_subset = |item: chunk_item::WithSubset| {
273311
// See zarrs::array::Array::retrieve_chunk_subset_into
274-
if item.chunk_subset.start().iter().all(|&o| o == 0)
275-
&& item.chunk_subset.shape() == item.representation().shape_u64()
276-
{
312+
if is_whole_chunk(&item) {
277313
// See zarrs::array::Array::retrieve_chunk_into
278314
if let Some(chunk_encoded) = self.stores.get(&item)? {
279315
// Decode the encoded data into the output buffer
@@ -308,12 +344,10 @@ impl CodecPipelineImpl {
308344
}
309345
}
310346
} else {
311-
let input_handle = Arc::new(self.stores.decoder(&item)?);
312-
let partial_decoder = self
313-
.codec_chain
314-
.clone()
315-
.partial_decoder(input_handle, item.representation(), &codec_options)
316-
.map_py_err::<PyValueError>()?;
347+
let key = item.key();
348+
let partial_decoder = partial_decoder_cache.get(key).ok_or_else(|| {
349+
PyRuntimeError::new_err(format!("Partial decoder not found for key: {key}"))
350+
})?;
317351
unsafe {
318352
// SAFETY:
319353
// - output is an array with output_shape elements of the item.representation data type,

src/utils.rs

+7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::fmt::Display;
33
use numpy::{PyUntypedArray, PyUntypedArrayMethods};
44
use pyo3::{Bound, PyErr, PyResult, PyTypeInfo};
55

6+
use crate::{ChunksItem, WithSubset};
7+
68
pub(crate) trait PyErrExt<T> {
79
fn map_py_err<PE: PyTypeInfo>(self) -> PyResult<T>;
810
}
@@ -29,3 +31,8 @@ impl PyUntypedArrayExt for Bound<'_, PyUntypedArray> {
2931
})
3032
}
3133
}
34+
35+
pub fn is_whole_chunk(item: &WithSubset) -> bool {
36+
item.chunk_subset.start().iter().all(|&o| o == 0)
37+
&& item.chunk_subset.shape() == item.representation().shape_u64()
38+
}

0 commit comments

Comments (0)