6
6
if TYPE_CHECKING :
7
7
from typing import Self
8
8
9
- import numpy .typing as npt
10
-
11
9
from zarr .core .buffer import Buffer , BufferPrototype
12
10
from zarr .core .chunk_grids import ChunkGrid
13
11
from zarr .core .common import JSON , ChunkCoords
20
18
21
19
import numcodecs .abc
22
20
import numpy as np
21
+ import numpy .typing as npt
23
22
24
23
from zarr .abc .codec import ArrayArrayCodec , ArrayBytesCodec , BytesBytesCodec , Codec
25
24
from zarr .core .array_spec import ArraySpec
31
30
from zarr .core .metadata .common import ArrayMetadata , parse_attributes
32
31
from zarr .registry import get_codec_class
33
32
33
+ DEFAULT_DTYPE = "float64"
34
+
34
35
35
36
def parse_zarr_format (data : object ) -> Literal [3 ]:
36
37
if data == 3 :
@@ -152,7 +153,7 @@ def _replace_special_floats(obj: object) -> Any:
152
153
@dataclass (frozen = True , kw_only = True )
153
154
class ArrayV3Metadata (ArrayMetadata ):
154
155
shape : ChunkCoords
155
- data_type : np . dtype [ Any ]
156
+ data_type : DataType
156
157
chunk_grid : ChunkGrid
157
158
chunk_key_encoding : ChunkKeyEncoding
158
159
fill_value : Any
@@ -167,7 +168,7 @@ def __init__(
167
168
self ,
168
169
* ,
169
170
shape : Iterable [int ],
170
- data_type : npt .DTypeLike ,
171
+ data_type : npt .DTypeLike | DataType ,
171
172
chunk_grid : dict [str , JSON ] | ChunkGrid ,
172
173
chunk_key_encoding : dict [str , JSON ] | ChunkKeyEncoding ,
173
174
fill_value : Any ,
@@ -180,18 +181,18 @@ def __init__(
180
181
Because the class is a frozen dataclass, we set attributes using object.__setattr__
181
182
"""
182
183
shape_parsed = parse_shapelike (shape )
183
- data_type_parsed = parse_dtype (data_type )
184
+ data_type_parsed = DataType . parse (data_type )
184
185
chunk_grid_parsed = ChunkGrid .from_dict (chunk_grid )
185
186
chunk_key_encoding_parsed = ChunkKeyEncoding .from_dict (chunk_key_encoding )
186
187
dimension_names_parsed = parse_dimension_names (dimension_names )
187
- fill_value_parsed = parse_fill_value (fill_value , dtype = data_type_parsed )
188
+ fill_value_parsed = parse_fill_value (fill_value , dtype = data_type_parsed . to_numpy () )
188
189
attributes_parsed = parse_attributes (attributes )
189
190
codecs_parsed_partial = parse_codecs (codecs )
190
191
storage_transformers_parsed = parse_storage_transformers (storage_transformers )
191
192
192
193
array_spec = ArraySpec (
193
194
shape = shape_parsed ,
194
- dtype = data_type_parsed ,
195
+ dtype = data_type_parsed . to_numpy () ,
195
196
fill_value = fill_value_parsed ,
196
197
order = "C" , # TODO: order is not needed here.
197
198
prototype = default_buffer_prototype (), # TODO: prototype is not needed here.
@@ -224,11 +225,14 @@ def _validate_metadata(self) -> None:
224
225
if self .fill_value is None :
225
226
raise ValueError ("`fill_value` is required." )
226
227
for codec in self .codecs :
227
- codec .validate (shape = self .shape , dtype = self .data_type , chunk_grid = self .chunk_grid )
228
+ codec .validate (
229
+ shape = self .shape , dtype = self .data_type .to_numpy (), chunk_grid = self .chunk_grid
230
+ )
228
231
229
232
@property
230
233
def dtype (self ) -> np .dtype [Any ]:
231
- return self .data_type
234
+ """Interpret Zarr dtype as NumPy dtype"""
235
+ return self .data_type .to_numpy ()
232
236
233
237
@property
234
238
def ndim (self ) -> int :
@@ -266,13 +270,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
266
270
_ = parse_node_type_array (_data .pop ("node_type" ))
267
271
268
272
# check that the data_type attribute is valid
269
- _ = DataType (_data [ "data_type" ] )
273
+ data_type = DataType . parse (_data . pop ( "data_type" ) )
270
274
271
275
# dimension_names key is optional, normalize missing to `None`
272
276
_data ["dimension_names" ] = _data .pop ("dimension_names" , None )
273
277
# attributes key is optional, normalize missing to `None`
274
278
_data ["attributes" ] = _data .pop ("attributes" , None )
275
- return cls (** _data ) # type: ignore[arg-type]
279
+ return cls (** _data , data_type = data_type ) # type: ignore[arg-type]
276
280
277
281
def to_dict (self ) -> dict [str , JSON ]:
278
282
out_dict = super ().to_dict ()
@@ -490,8 +494,11 @@ def to_numpy_shortname(self) -> str:
490
494
}
491
495
return data_type_to_numpy [self ]
492
496
497
+ def to_numpy (self ) -> np .dtype [Any ]:
498
+ return np .dtype (self .to_numpy_shortname ())
499
+
493
500
@classmethod
494
- def from_dtype (cls , dtype : np .dtype [Any ]) -> DataType :
501
+ def from_numpy (cls , dtype : np .dtype [Any ]) -> DataType :
495
502
dtype_to_data_type = {
496
503
"|b1" : "bool" ,
497
504
"bool" : "bool" ,
@@ -511,16 +518,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
511
518
}
512
519
return DataType [dtype_to_data_type [dtype .str ]]
513
520
514
-
515
- def parse_dtype (data : npt .DTypeLike ) -> np .dtype [Any ]:
516
- try :
517
- dtype = np .dtype (data )
518
- except (ValueError , TypeError ) as e :
519
- raise ValueError (f"Invalid V3 data_type: { data } " ) from e
520
- # check that this is a valid v3 data_type
521
- try :
522
- _ = DataType .from_dtype (dtype )
523
- except KeyError as e :
524
- raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
525
-
526
- return dtype
521
+ @classmethod
522
+ def parse (cls , dtype : None | DataType | Any ) -> DataType :
523
+ if dtype is None :
524
+ # the default dtype
525
+ return DataType [DEFAULT_DTYPE ]
526
+ if isinstance (dtype , DataType ):
527
+ return dtype
528
+ else :
529
+ try :
530
+ dtype = np .dtype (dtype )
531
+ except (ValueError , TypeError ) as e :
532
+ raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
533
+ # check that this is a valid v3 data_type
534
+ try :
535
+ data_type = DataType .from_numpy (dtype )
536
+ except KeyError as e :
537
+ raise ValueError (f"Invalid V3 data_type: { dtype } " ) from e
538
+ return data_type
0 commit comments