Skip to content

Commit 01346dd

Browse files
authored
Change ArrayV3Metadata.data_type to DataType (#2278)
* change v3.metadata.data_type type * implement suggestions
1 parent 8dd1f24 commit 01346dd

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

Diff for: src/zarr/core/metadata/v3.py

+37-25
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
if TYPE_CHECKING:
77
from typing import Self
88

9-
import numpy.typing as npt
10-
119
from zarr.core.buffer import Buffer, BufferPrototype
1210
from zarr.core.chunk_grids import ChunkGrid
1311
from zarr.core.common import JSON, ChunkCoords
@@ -20,6 +18,7 @@
2018

2119
import numcodecs.abc
2220
import numpy as np
21+
import numpy.typing as npt
2322

2423
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
2524
from zarr.core.array_spec import ArraySpec
@@ -31,6 +30,8 @@
3130
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
3231
from zarr.registry import get_codec_class
3332

33+
DEFAULT_DTYPE = "float64"
34+
3435

3536
def parse_zarr_format(data: object) -> Literal[3]:
3637
if data == 3:
@@ -152,7 +153,7 @@ def _replace_special_floats(obj: object) -> Any:
152153
@dataclass(frozen=True, kw_only=True)
153154
class ArrayV3Metadata(ArrayMetadata):
154155
shape: ChunkCoords
155-
data_type: np.dtype[Any]
156+
data_type: DataType
156157
chunk_grid: ChunkGrid
157158
chunk_key_encoding: ChunkKeyEncoding
158159
fill_value: Any
@@ -167,7 +168,7 @@ def __init__(
167168
self,
168169
*,
169170
shape: Iterable[int],
170-
data_type: npt.DTypeLike,
171+
data_type: npt.DTypeLike | DataType,
171172
chunk_grid: dict[str, JSON] | ChunkGrid,
172173
chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding,
173174
fill_value: Any,
@@ -180,18 +181,18 @@ def __init__(
180181
Because the class is a frozen dataclass, we set attributes using object.__setattr__
181182
"""
182183
shape_parsed = parse_shapelike(shape)
183-
data_type_parsed = parse_dtype(data_type)
184+
data_type_parsed = DataType.parse(data_type)
184185
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
185186
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
186187
dimension_names_parsed = parse_dimension_names(dimension_names)
187-
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed)
188+
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy())
188189
attributes_parsed = parse_attributes(attributes)
189190
codecs_parsed_partial = parse_codecs(codecs)
190191
storage_transformers_parsed = parse_storage_transformers(storage_transformers)
191192

192193
array_spec = ArraySpec(
193194
shape=shape_parsed,
194-
dtype=data_type_parsed,
195+
dtype=data_type_parsed.to_numpy(),
195196
fill_value=fill_value_parsed,
196197
order="C", # TODO: order is not needed here.
197198
prototype=default_buffer_prototype(), # TODO: prototype is not needed here.
@@ -224,11 +225,14 @@ def _validate_metadata(self) -> None:
224225
if self.fill_value is None:
225226
raise ValueError("`fill_value` is required.")
226227
for codec in self.codecs:
227-
codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid)
228+
codec.validate(
229+
shape=self.shape, dtype=self.data_type.to_numpy(), chunk_grid=self.chunk_grid
230+
)
228231

229232
@property
230233
def dtype(self) -> np.dtype[Any]:
231-
return self.data_type
234+
"""Interpret Zarr dtype as NumPy dtype"""
235+
return self.data_type.to_numpy()
232236

233237
@property
234238
def ndim(self) -> int:
@@ -266,13 +270,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
266270
_ = parse_node_type_array(_data.pop("node_type"))
267271

268272
# check that the data_type attribute is valid
269-
_ = DataType(_data["data_type"])
273+
data_type = DataType.parse(_data.pop("data_type"))
270274

271275
# dimension_names key is optional, normalize missing to `None`
272276
_data["dimension_names"] = _data.pop("dimension_names", None)
273277
# attributes key is optional, normalize missing to `None`
274278
_data["attributes"] = _data.pop("attributes", None)
275-
return cls(**_data) # type: ignore[arg-type]
279+
return cls(**_data, data_type=data_type) # type: ignore[arg-type]
276280

277281
def to_dict(self) -> dict[str, JSON]:
278282
out_dict = super().to_dict()
@@ -490,8 +494,11 @@ def to_numpy_shortname(self) -> str:
490494
}
491495
return data_type_to_numpy[self]
492496

497+
def to_numpy(self) -> np.dtype[Any]:
498+
return np.dtype(self.to_numpy_shortname())
499+
493500
@classmethod
494-
def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
501+
def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
495502
dtype_to_data_type = {
496503
"|b1": "bool",
497504
"bool": "bool",
@@ -511,16 +518,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
511518
}
512519
return DataType[dtype_to_data_type[dtype.str]]
513520

514-
515-
def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
516-
try:
517-
dtype = np.dtype(data)
518-
except (ValueError, TypeError) as e:
519-
raise ValueError(f"Invalid V3 data_type: {data}") from e
520-
# check that this is a valid v3 data_type
521-
try:
522-
_ = DataType.from_dtype(dtype)
523-
except KeyError as e:
524-
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
525-
526-
return dtype
521+
@classmethod
522+
def parse(cls, dtype: None | DataType | Any) -> DataType:
523+
if dtype is None:
524+
# the default dtype
525+
return DataType[DEFAULT_DTYPE]
526+
if isinstance(dtype, DataType):
527+
return dtype
528+
else:
529+
try:
530+
dtype = np.dtype(dtype)
531+
except (ValueError, TypeError) as e:
532+
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
533+
# check that this is a valid v3 data_type
534+
try:
535+
data_type = DataType.from_numpy(dtype)
536+
except KeyError as e:
537+
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
538+
return data_type

Diff for: tests/v3/test_metadata/test_v3.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from zarr.codecs.bytes import BytesCodec
88
from zarr.core.buffer import default_buffer_prototype
99
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
10-
from zarr.core.metadata.v3 import ArrayV3Metadata
10+
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType
1111

1212
if TYPE_CHECKING:
1313
from collections.abc import Sequence
@@ -22,7 +22,6 @@
2222

2323
from zarr.core.metadata.v3 import (
2424
parse_dimension_names,
25-
parse_dtype,
2625
parse_fill_value,
2726
parse_zarr_format,
2827
)
@@ -209,7 +208,7 @@ def test_metadata_to_dict(
209208
storage_transformers: None | tuple[dict[str, JSON]],
210209
) -> None:
211210
shape = (1, 2, 3)
212-
data_type = "uint8"
211+
data_type = DataType.uint8
213212
if chunk_grid == "regular":
214213
cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}}
215214

@@ -290,7 +289,7 @@ def test_metadata_to_dict(
290289
# assert result["fill_value"] == fill_value
291290

292291

293-
async def test_invalid_dtype_raises() -> None:
292+
def test_invalid_dtype_raises() -> None:
294293
metadata_dict = {
295294
"zarr_format": 3,
296295
"node_type": "array",
@@ -301,14 +300,14 @@ async def test_invalid_dtype_raises() -> None:
301300
"codecs": (),
302301
"fill_value": np.datetime64(0, "ns"),
303302
}
304-
with pytest.raises(ValueError, match=r".* is not a valid DataType"):
303+
with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
305304
ArrayV3Metadata.from_dict(metadata_dict)
306305

307306

308307
@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
309308
def test_parse_invalid_dtype_raises(data):
310309
with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
311-
parse_dtype(data)
310+
DataType.parse(data)
312311

313312

314313
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)