@@ -7,17 +7,24 @@ use anyhow::{bail, Context, Result};
77use ndarray:: { Array , ArrayD , ArrayView , CowArray , Dimension , IxDyn , SliceInfoElem } ;
88use std:: {
99 borrow:: Cow ,
10+ num:: NonZeroU64 ,
1011 ops:: { Deref , Index } ,
1112 path:: { Path , PathBuf } ,
1213} ;
1314use std:: { sync:: Arc , vec} ;
14- use zarrs:: array:: codec:: bytes_to_bytes:: zstd:: ZstdCodec ;
15+ use zarrs:: array:: {
16+ ZARR_NAN_F32 , ZARR_NAN_F64 , codec:: bytes_to_bytes:: zstd:: ZstdCodec , data_type:: {
17+ BoolDataType , Float32DataType , Float64DataType , Int8DataType , Int16DataType , Int32DataType , Int64DataType , StringDataType , UInt8DataType , UInt16DataType , UInt32DataType , UInt64DataType
18+ }
19+ } ;
1520use zarrs:: filesystem:: FilesystemStore ;
1621use zarrs:: group:: Group ;
1722use zarrs:: { array:: ElementOwned , storage:: ReadableWritableListableStorageTraits } ;
1823use zarrs:: {
19- array:: { codec:: ShardingCodecBuilder , data_type:: DataType , ArrayShardedReadableExt , Element } ,
20- array_subset:: ArraySubset ,
24+ array:: {
25+ codec:: ShardingCodecBuilder , data_type, ArrayShardedReadableExt , ArraySubset , Element ,
26+ FillValue ,
27+ } ,
2128 storage:: StorePrefix ,
2229} ;
2330
@@ -344,20 +351,32 @@ impl AttributeOp<Zarr> for ZarrDataset {
344351
345352impl DatasetOp < Zarr > for ZarrDataset {
346353 fn dtype ( & self ) -> Result < ScalarType > {
347- match self . dataset . data_type ( ) {
348- DataType :: UInt8 => Ok ( ScalarType :: U8 ) ,
349- DataType :: UInt16 => Ok ( ScalarType :: U16 ) ,
350- DataType :: UInt32 => Ok ( ScalarType :: U32 ) ,
351- DataType :: UInt64 => Ok ( ScalarType :: U64 ) ,
352- DataType :: Int8 => Ok ( ScalarType :: I8 ) ,
353- DataType :: Int16 => Ok ( ScalarType :: I16 ) ,
354- DataType :: Int32 => Ok ( ScalarType :: I32 ) ,
355- DataType :: Int64 => Ok ( ScalarType :: I64 ) ,
356- DataType :: Float32 => Ok ( ScalarType :: F32 ) ,
357- DataType :: Float64 => Ok ( ScalarType :: F64 ) ,
358- DataType :: Bool => Ok ( ScalarType :: Bool ) ,
359- DataType :: String => Ok ( ScalarType :: String ) ,
360- ty => bail ! ( "Unsupported type: {:?}" , ty) ,
354+ if self . dataset . data_type ( ) . is :: < UInt8DataType > ( ) {
355+ Ok ( ScalarType :: U8 )
356+ } else if self . dataset . data_type ( ) . is :: < UInt16DataType > ( ) {
357+ Ok ( ScalarType :: U16 )
358+ } else if self . dataset . data_type ( ) . is :: < UInt32DataType > ( ) {
359+ Ok ( ScalarType :: U32 )
360+ } else if self . dataset . data_type ( ) . is :: < UInt64DataType > ( ) {
361+ Ok ( ScalarType :: U64 )
362+ } else if self . dataset . data_type ( ) . is :: < Int8DataType > ( ) {
363+ Ok ( ScalarType :: I8 )
364+ } else if self . dataset . data_type ( ) . is :: < Int16DataType > ( ) {
365+ Ok ( ScalarType :: I16 )
366+ } else if self . dataset . data_type ( ) . is :: < Int32DataType > ( ) {
367+ Ok ( ScalarType :: I32 )
368+ } else if self . dataset . data_type ( ) . is :: < Int64DataType > ( ) {
369+ Ok ( ScalarType :: I64 )
370+ } else if self . dataset . data_type ( ) . is :: < Float32DataType > ( ) {
371+ Ok ( ScalarType :: F32 )
372+ } else if self . dataset . data_type ( ) . is :: < Float64DataType > ( ) {
373+ Ok ( ScalarType :: F64 )
374+ } else if self . dataset . data_type ( ) . is :: < BoolDataType > ( ) {
375+ Ok ( ScalarType :: Bool )
376+ } else if self . dataset . data_type ( ) . is :: < StringDataType > ( ) {
377+ Ok ( ScalarType :: String )
378+ } else {
379+ bail ! ( "Unsupported type: {:?}" , self . dataset. data_type( ) )
361380 }
362381 }
363382
@@ -371,8 +390,11 @@ impl DatasetOp<Zarr> for ZarrDataset {
371390
372391 fn reshape ( & mut self , shape : & Shape ) -> Result < ( ) > {
373392 self . dataset
374- . set_shape ( shape. as_ref ( ) . iter ( ) . map ( |x| * x as u64 ) . collect ( ) ) ;
393+ . set_shape ( shape. as_ref ( ) . iter ( ) . map ( |x| * x as u64 ) . collect ( ) ) ? ;
375394 self . dataset . store_metadata ( ) ?;
395+ // The intenion of the caching API is no mutation after creation:
396+ // https://ossci.zulipchat.com/#narrow/channel/423692-Zarr/topic/zarrs.20.60ArrayShardedReadableExtCache.60.20.2B.20.60set_shape.60/with/595519775
397+ self . cache . clear ( ) ;
376398 Ok ( ( ) )
377399 }
378400
@@ -392,21 +414,21 @@ impl DatasetOp<Zarr> for ZarrDataset {
392414 if let Some ( subset) = to_array_subset ( sel) {
393415 let arr = dataset
394416 . dataset
395- . retrieve_array_subset_ndarray_sharded_opt (
417+ . retrieve_array_subset_sharded_opt :: < ndarray :: ArrayD < T > > (
396418 & dataset. cache ,
397419 & subset,
398- & zarrs:: array:: codec :: CodecOptions :: default ( ) ,
420+ & zarrs:: array:: CodecOptions :: default ( ) ,
399421 ) ?
400422 . into_dimensionality :: < D > ( ) ?;
401423 Ok ( arr)
402424 } else {
403425 // Read the entire array and then select the slice.
404426 let arr = dataset
405427 . dataset
406- . retrieve_array_subset_ndarray_sharded_opt (
428+ . retrieve_array_subset_sharded_opt :: < ndarray :: ArrayD < T > > (
407429 & dataset. cache ,
408430 & dataset. dataset . subset_all ( ) ,
409- & zarrs:: array:: codec :: CodecOptions :: default ( ) ,
431+ & zarrs:: array:: CodecOptions :: default ( ) ,
410432 ) ?
411433 . into_dimensionality :: < D > ( ) ?;
412434 Ok ( select ( arr. view ( ) , selection) )
@@ -461,9 +483,10 @@ impl DatasetOp<Zarr> for ZarrDataset {
461483 } )
462484 . collect ( ) ;
463485 if starts. len ( ) == selection. ndim ( ) {
486+ container. cache . clear ( ) ;
464487 container
465488 . dataset
466- . store_array_subset_ndarray ( starts. as_slice ( ) , arr. into_owned ( ) ) ?;
489+ . store_array_subset ( & ArraySubset :: new_with_start_shape ( starts, arr . shape ( ) . iter ( ) . map ( |x| * x as u64 ) . collect ( ) ) ? , arr. to_owned ( ) ) ?;
467490 } else {
468491 panic ! ( "Not implemented" ) ;
469492 }
@@ -567,23 +590,27 @@ fn new_empty_dataset_helper<T: BackendData, S: ?Sized>(
567590 config : WriteConfig ,
568591) -> Result < zarrs:: array:: Array < S > > {
569592 let ( datatype, fill) = match T :: DTYPE {
570- ScalarType :: U8 => ( DataType :: UInt8 , 0u8 . into ( ) ) ,
571- ScalarType :: U16 => ( DataType :: UInt16 , 0u16 . into ( ) ) ,
572- ScalarType :: U32 => ( DataType :: UInt32 , 0u32 . into ( ) ) ,
573- ScalarType :: U64 => ( DataType :: UInt64 , 0u64 . into ( ) ) ,
574- ScalarType :: I8 => ( DataType :: Int8 , 0i8 . into ( ) ) ,
575- ScalarType :: I16 => ( DataType :: Int16 , 0i16 . into ( ) ) ,
576- ScalarType :: I32 => ( DataType :: Int32 , 0i32 . into ( ) ) ,
577- ScalarType :: I64 => ( DataType :: Int64 , 0i64 . into ( ) ) ,
578- ScalarType :: F32 => ( DataType :: Float32 , zarrs :: array :: ZARR_NAN_F32 . into ( ) ) ,
579- ScalarType :: F64 => ( DataType :: Float64 , zarrs :: array :: ZARR_NAN_F64 . into ( ) ) ,
580- ScalarType :: Bool => ( DataType :: Bool , false . into ( ) ) ,
581- ScalarType :: String => ( DataType :: String , "" . into ( ) ) ,
593+ ScalarType :: U8 => ( data_type :: uint8 ( ) , FillValue :: from ( 0u8 ) ) ,
594+ ScalarType :: U16 => ( data_type :: uint16 ( ) , FillValue :: from ( 0u16 ) ) ,
595+ ScalarType :: U32 => ( data_type :: uint32 ( ) , FillValue :: from ( 0u32 ) ) ,
596+ ScalarType :: U64 => ( data_type :: uint64 ( ) , FillValue :: from ( 0u64 ) ) ,
597+ ScalarType :: I8 => ( data_type :: int8 ( ) , FillValue :: from ( 0i8 ) ) ,
598+ ScalarType :: I16 => ( data_type :: int16 ( ) , FillValue :: from ( 0i16 ) ) ,
599+ ScalarType :: I32 => ( data_type :: int32 ( ) , FillValue :: from ( 0i32 ) ) ,
600+ ScalarType :: I64 => ( data_type :: int64 ( ) , FillValue :: from ( 0i64 ) ) ,
601+ ScalarType :: F32 => ( data_type :: float32 ( ) , FillValue :: from ( ZARR_NAN_F32 ) ) ,
602+ ScalarType :: F64 => ( data_type :: float64 ( ) , FillValue :: from ( ZARR_NAN_F64 ) ) ,
603+ ScalarType :: Bool => ( data_type :: bool ( ) , FillValue :: from ( false ) ) ,
604+ ScalarType :: String => ( data_type :: string ( ) , FillValue :: from ( "" ) ) ,
582605 } ;
583606
584607 let shape = shape. as_ref ( ) ;
585608 let chunk_size: Vec < u64 > = match config. block_size {
586- Some ( s) => s. as_ref ( ) . into_iter ( ) . map ( |x| ( * x) . max ( 1 ) as u64 ) . collect ( ) ,
609+ Some ( s) => s
610+ . as_ref ( )
611+ . into_iter ( )
612+ . map ( |x| ( * x) . max ( 1 ) as u64 )
613+ . collect :: < Vec < _ > > ( ) ,
587614 _ => {
588615 if shape. len ( ) == 1 {
589616 vec ! [ shape[ 0 ] . min( 16384 ) . max( 1 ) as u64 ]
@@ -594,29 +621,35 @@ fn new_empty_dataset_helper<T: BackendData, S: ?Sized>(
594621 } ;
595622
596623 let mut use_sharding = true ;
597- if matches ! ( datatype, DataType :: String ) { //|| shape.iter().sum::<usize>() == 0 {
624+ if datatype == data_type:: string ( ) {
625+ //|| shape.iter().sum::<usize>() == 0 {
598626 // Strings are not sharded, they are stored as a single chunk.
599627 use_sharding = false ;
600628 }
601629
602630 let array = if use_sharding {
603631 let shard_shape = chunk_size. iter ( ) . map ( |& x| x * 8 ) . collect :: < Vec < _ > > ( ) ;
604- let mut sharding_codec_builder =
605- ShardingCodecBuilder :: new ( chunk_size. try_into ( ) ?) ;
632+ let mut sharding_codec_builder = ShardingCodecBuilder :: new (
633+ chunk_size
634+ . iter ( )
635+ . map ( |e| NonZeroU64 :: try_from ( * e) )
636+ . collect :: < Result < Vec < NonZeroU64 > , _ > > ( ) ?,
637+ & datatype,
638+ ) ;
606639 sharding_codec_builder. bytes_to_bytes_codecs ( vec ! [ Arc :: new( ZstdCodec :: new( 7 , false ) ) ] ) ;
607640 zarrs:: array:: ArrayBuilder :: new (
608- shape. iter ( ) . map ( |x| * x as u64 ) . collect ( ) ,
641+ shape. iter ( ) . map ( |x| * x as u64 ) . collect :: < Vec < _ > > ( ) ,
642+ shard_shape. as_slice ( ) ,
609643 datatype,
610- shard_shape. try_into ( ) ?,
611644 fill,
612645 )
613646 . array_to_bytes_codec ( sharding_codec_builder. build_arc ( ) )
614647 . build ( store, path) ?
615648 } else {
616649 zarrs:: array:: ArrayBuilder :: new (
617- shape. iter ( ) . map ( |x| * x as u64 ) . collect ( ) ,
650+ shape. iter ( ) . map ( |x| * x as u64 ) . collect :: < Vec < _ > > ( ) ,
651+ chunk_size. as_slice ( ) ,
618652 datatype,
619- chunk_size. try_into ( ) ?,
620653 fill,
621654 )
622655 . bytes_to_bytes_codecs ( vec ! [ Arc :: new( ZstdCodec :: new( 7 , false ) ) ] )
@@ -710,15 +743,22 @@ mod tests {
710743 let mut dataset =
711744 group. new_empty_dataset :: < i32 > ( "test" , & [ 20 , 50 ] . as_slice ( ) . into ( ) , config) ?;
712745
713- let arr = Array :: random ( ( 10 , 10 ) , Uniform :: new ( 0 , 100 ) ) ;
746+ // Repeated writes force cache clearance
747+ let arr: ndarray:: Array2 < i32 > = Array :: random ( ( 10 , 10 ) , Uniform :: new ( 0 , 100 ) . unwrap ( ) ) ;
748+ dataset. write_array_slice ( arr. view ( ) . into ( ) , s ! [ 5 ..15 , 10 ..20 ] . as_ref ( ) ) ?;
749+ assert_eq ! (
750+ arr,
751+ dataset. read_array_slice:: <i32 , _, _>( s![ 5 ..15 , 10 ..20 ] . as_ref( ) ) ?
752+ ) ;
753+ let arr: ndarray:: Array2 < i32 > = Array :: random ( ( 10 , 10 ) , Uniform :: new ( 0 , 100 ) . unwrap ( ) ) ;
714754 dataset. write_array_slice ( arr. view ( ) . into ( ) , s ! [ 5 ..15 , 10 ..20 ] . as_ref ( ) ) ?;
715755 assert_eq ! (
716756 arr,
717757 dataset. read_array_slice:: <i32 , _, _>( s![ 5 ..15 , 10 ..20 ] . as_ref( ) ) ?
718758 ) ;
719759
720760 // Repeatitive writes
721- let arr = Array :: random ( ( 20 , 50 ) , Uniform :: new ( 0 , 100 ) ) ;
761+ let arr = Array :: random ( ( 20 , 50 ) , Uniform :: new ( 0 , 100 ) . unwrap ( ) ) ;
722762 dataset. write_array_slice ( arr. view ( ) . into ( ) , s ! [ .., ..] . as_ref ( ) ) ?;
723763 dataset. write_array_slice ( arr. view ( ) . into ( ) , s ! [ .., ..] . as_ref ( ) ) ?;
724764
0 commit comments