Skip to content

Commit 0328656

Browse files
maxrjonesnormanrz
andauthored
Use dataclasses for ByteRangeRequests (#2585)
* Use TypedDicts for more literate ByteRangeRequests * Update utility function * fixes sharding * Ignore mypy errors * Fix offset in _normalize_byte_range_index * Update get_partial_values for FsspecStore * Re-add fs._cat_ranges argument * Simplify typing * Update _normalize to return start, stop * Use explicit range * Use dataclasses * Update typing * Update docstring * Rename ExplicitRange to ExplicitByteRequest * Rename OffsetRange to OffsetByteRequest * Rename SuffixRange to SuffixByteRequest * Use match; case instead of if; elif * Revert "Use match; case instead of if; elif" This reverts commit a7d35f8. * Update ByteRangeRequest to ByteRequest * Remove ByteRange definition from common * Rename ExplicitByteRequest to RangeByteRequest * Provide more informative error message --------- Co-authored-by: Norman Rzepka <[email protected]>
1 parent 22ebded commit 0328656

File tree

15 files changed

+221
-145
lines changed

15 files changed

+221
-145
lines changed

Diff for: src/zarr/abc/store.py

+42-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from abc import ABC, abstractmethod
44
from asyncio import gather
5+
from dataclasses import dataclass
56
from itertools import starmap
67
from typing import TYPE_CHECKING, Protocol, runtime_checkable
78

@@ -19,7 +20,34 @@
1920

2021
__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]
2122

22-
ByteRangeRequest: TypeAlias = tuple[int | None, int | None]
23+
24+
@dataclass
25+
class RangeByteRequest:
26+
"""Request a specific byte range"""
27+
28+
start: int
29+
"""The start of the byte range request (inclusive)."""
30+
end: int
31+
"""The end of the byte range request (exclusive)."""
32+
33+
34+
@dataclass
35+
class OffsetByteRequest:
36+
"""Request all bytes starting from a given byte offset"""
37+
38+
offset: int
39+
"""The byte offset for the offset range request."""
40+
41+
42+
@dataclass
43+
class SuffixByteRequest:
44+
"""Request up to the last `n` bytes"""
45+
46+
suffix: int
47+
"""The number of bytes from the suffix to request."""
48+
49+
50+
ByteRequest: TypeAlias = RangeByteRequest | OffsetByteRequest | SuffixByteRequest
2351

2452

2553
class Store(ABC):
@@ -141,14 +169,20 @@ async def get(
141169
self,
142170
key: str,
143171
prototype: BufferPrototype,
144-
byte_range: ByteRangeRequest | None = None,
172+
byte_range: ByteRequest | None = None,
145173
) -> Buffer | None:
146174
"""Retrieve the value associated with a given key.
147175
148176
Parameters
149177
----------
150178
key : str
151-
byte_range : tuple[int | None, int | None], optional
179+
byte_range : ByteRequest, optional
180+
181+
ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved.
182+
183+
- RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned.
184+
- OffsetByteRequest(int): Request all bytes starting from a given byte offset. This is equivalent to bytes={int}- as an HTTP header.
185+
- SuffixByteRequest(int): Request the last int bytes. Note that here, int is the size of the request, not the byte offset. This is equivalent to bytes=-{int} as an HTTP header.
152186
153187
Returns
154188
-------
@@ -160,7 +194,7 @@ async def get(
160194
async def get_partial_values(
161195
self,
162196
prototype: BufferPrototype,
163-
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
197+
key_ranges: Iterable[tuple[str, ByteRequest | None]],
164198
) -> list[Buffer | None]:
165199
"""Retrieve possibly partial values from given key_ranges.
166200
@@ -338,7 +372,7 @@ def close(self) -> None:
338372
self._is_open = False
339373

340374
async def _get_many(
341-
self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]]
375+
self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]]
342376
) -> AsyncGenerator[tuple[str, Buffer | None], None]:
343377
"""
344378
Retrieve a collection of objects from storage. In general this method does not guarantee
@@ -416,17 +450,17 @@ async def getsize_prefix(self, prefix: str) -> int:
416450
@runtime_checkable
417451
class ByteGetter(Protocol):
418452
async def get(
419-
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
453+
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
420454
) -> Buffer | None: ...
421455

422456

423457
@runtime_checkable
424458
class ByteSetter(Protocol):
425459
async def get(
426-
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
460+
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
427461
) -> Buffer | None: ...
428462

429-
async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: ...
463+
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: ...
430464

431465
async def delete(self) -> None: ...
432466

Diff for: src/zarr/codecs/sharding.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
Codec,
1818
CodecPipeline,
1919
)
20-
from zarr.abc.store import ByteGetter, ByteRangeRequest, ByteSetter
20+
from zarr.abc.store import (
21+
ByteGetter,
22+
ByteRequest,
23+
ByteSetter,
24+
RangeByteRequest,
25+
SuffixByteRequest,
26+
)
2127
from zarr.codecs.bytes import BytesCodec
2228
from zarr.codecs.crc32c_ import Crc32cCodec
2329
from zarr.core.array_spec import ArrayConfig, ArraySpec
@@ -77,7 +83,7 @@ class _ShardingByteGetter(ByteGetter):
7783
chunk_coords: ChunkCoords
7884

7985
async def get(
80-
self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
86+
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
8187
) -> Buffer | None:
8288
assert byte_range is None, "byte_range is not supported within shards"
8389
assert (
@@ -90,7 +96,7 @@ async def get(
9096
class _ShardingByteSetter(_ShardingByteGetter, ByteSetter):
9197
shard_dict: ShardMutableMapping
9298

93-
async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None:
99+
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
94100
assert byte_range is None, "byte_range is not supported within shards"
95101
self.shard_dict[self.chunk_coords] = value
96102

@@ -129,7 +135,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None:
129135
if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
130136
return None
131137
else:
132-
return (int(chunk_start), int(chunk_len))
138+
return (int(chunk_start), int(chunk_start + chunk_len))
133139

134140
def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None:
135141
localized_chunk = self._localize_chunk(chunk_coords)
@@ -203,7 +209,7 @@ def create_empty(
203209
def __getitem__(self, chunk_coords: ChunkCoords) -> Buffer:
204210
chunk_byte_slice = self.index.get_chunk_slice(chunk_coords)
205211
if chunk_byte_slice:
206-
return self.buf[chunk_byte_slice[0] : (chunk_byte_slice[0] + chunk_byte_slice[1])]
212+
return self.buf[chunk_byte_slice[0] : chunk_byte_slice[1]]
207213
raise KeyError
208214

209215
def __len__(self) -> int:
@@ -504,7 +510,8 @@ async def _decode_partial_single(
504510
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
505511
if chunk_byte_slice:
506512
chunk_bytes = await byte_getter.get(
507-
prototype=chunk_spec.prototype, byte_range=chunk_byte_slice
513+
prototype=chunk_spec.prototype,
514+
byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
508515
)
509516
if chunk_bytes:
510517
shard_dict[chunk_coords] = chunk_bytes
@@ -696,11 +703,12 @@ async def _load_shard_index_maybe(
696703
shard_index_size = self._shard_index_size(chunks_per_shard)
697704
if self.index_location == ShardingCodecIndexLocation.start:
698705
index_bytes = await byte_getter.get(
699-
prototype=numpy_buffer_prototype(), byte_range=(0, shard_index_size)
706+
prototype=numpy_buffer_prototype(),
707+
byte_range=RangeByteRequest(0, shard_index_size),
700708
)
701709
else:
702710
index_bytes = await byte_getter.get(
703-
prototype=numpy_buffer_prototype(), byte_range=(-shard_index_size, None)
711+
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
704712
)
705713
if index_bytes is not None:
706714
return await self._decode_shard_index(index_bytes, chunks_per_shard)

Diff for: src/zarr/core/common.py

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
ZATTRS_JSON = ".zattrs"
3232
ZMETADATA_V2_JSON = ".zmetadata"
3333

34-
ByteRangeRequest = tuple[int | None, int | None]
3534
BytesLike = bytes | bytearray | memoryview
3635
ShapeLike = tuple[int, ...] | int
3736
ChunkCoords = tuple[int, ...]

Diff for: src/zarr/storage/_common.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING, Any, Literal
66

7-
from zarr.abc.store import ByteRangeRequest, Store
7+
from zarr.abc.store import ByteRequest, Store
88
from zarr.core.buffer import Buffer, default_buffer_prototype
99
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, AccessModeLiteral, ZarrFormat
1010
from zarr.errors import ContainsArrayAndGroupError, ContainsArrayError, ContainsGroupError
@@ -102,7 +102,7 @@ async def open(
102102
async def get(
103103
self,
104104
prototype: BufferPrototype | None = None,
105-
byte_range: ByteRangeRequest | None = None,
105+
byte_range: ByteRequest | None = None,
106106
) -> Buffer | None:
107107
"""
108108
Read bytes from the store.
@@ -111,7 +111,7 @@ async def get(
111111
----------
112112
prototype : BufferPrototype, optional
113113
The buffer prototype to use when reading the bytes.
114-
byte_range : ByteRangeRequest, optional
114+
byte_range : ByteRequest, optional
115115
The range of bytes to read.
116116
117117
Returns
@@ -123,15 +123,15 @@ async def get(
123123
prototype = default_buffer_prototype()
124124
return await self.store.get(self.path, prototype=prototype, byte_range=byte_range)
125125

126-
async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None:
126+
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
127127
"""
128128
Write bytes to the store.
129129
130130
Parameters
131131
----------
132132
value : Buffer
133133
The buffer to write.
134-
byte_range : ByteRangeRequest, optional
134+
byte_range : ByteRequest, optional
135135
The range of bytes to write. If None, the entire buffer is written.
136136
137137
Raises

Diff for: src/zarr/storage/_fsspec.py

+50-31
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import warnings
44
from typing import TYPE_CHECKING, Any
55

6-
from zarr.abc.store import ByteRangeRequest, Store
6+
from zarr.abc.store import (
7+
ByteRequest,
8+
OffsetByteRequest,
9+
RangeByteRequest,
10+
Store,
11+
SuffixByteRequest,
12+
)
713
from zarr.storage._common import _dereference_path
814

915
if TYPE_CHECKING:
@@ -199,31 +205,34 @@ async def get(
199205
self,
200206
key: str,
201207
prototype: BufferPrototype,
202-
byte_range: ByteRangeRequest | None = None,
208+
byte_range: ByteRequest | None = None,
203209
) -> Buffer | None:
204210
# docstring inherited
205211
if not self._is_open:
206212
await self._open()
207213
path = _dereference_path(self.path, key)
208214

209215
try:
210-
if byte_range:
211-
# fsspec uses start/end, not start/length
212-
start, length = byte_range
213-
if start is not None and length is not None:
214-
end = start + length
215-
elif length is not None:
216-
end = length
217-
else:
218-
end = None
219-
value = prototype.buffer.from_bytes(
220-
await (
221-
self.fs._cat_file(path, start=byte_range[0], end=end)
222-
if byte_range
223-
else self.fs._cat_file(path)
216+
if byte_range is None:
217+
value = prototype.buffer.from_bytes(await self.fs._cat_file(path))
218+
elif isinstance(byte_range, RangeByteRequest):
219+
value = prototype.buffer.from_bytes(
220+
await self.fs._cat_file(
221+
path,
222+
start=byte_range.start,
223+
end=byte_range.end,
224+
)
224225
)
225-
)
226-
226+
elif isinstance(byte_range, OffsetByteRequest):
227+
value = prototype.buffer.from_bytes(
228+
await self.fs._cat_file(path, start=byte_range.offset, end=None)
229+
)
230+
elif isinstance(byte_range, SuffixByteRequest):
231+
value = prototype.buffer.from_bytes(
232+
await self.fs._cat_file(path, start=-byte_range.suffix, end=None)
233+
)
234+
else:
235+
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
227236
except self.allowed_exceptions:
228237
return None
229238
except OSError as e:
@@ -270,25 +279,35 @@ async def exists(self, key: str) -> bool:
270279
async def get_partial_values(
271280
self,
272281
prototype: BufferPrototype,
273-
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
282+
key_ranges: Iterable[tuple[str, ByteRequest | None]],
274283
) -> list[Buffer | None]:
275284
# docstring inherited
276285
if key_ranges:
277-
paths, starts, stops = zip(
278-
*(
279-
(
280-
_dereference_path(self.path, k[0]),
281-
k[1][0],
282-
((k[1][0] or 0) + k[1][1]) if k[1][1] is not None else None,
283-
)
284-
for k in key_ranges
285-
),
286-
strict=False,
287-
)
286+
# _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest.
287+
key_ranges = list(key_ranges)
288+
paths: list[str] = []
289+
starts: list[int | None] = []
290+
stops: list[int | None] = []
291+
for key, byte_range in key_ranges:
292+
paths.append(_dereference_path(self.path, key))
293+
if byte_range is None:
294+
starts.append(None)
295+
stops.append(None)
296+
elif isinstance(byte_range, RangeByteRequest):
297+
starts.append(byte_range.start)
298+
stops.append(byte_range.end)
299+
elif isinstance(byte_range, OffsetByteRequest):
300+
starts.append(byte_range.offset)
301+
stops.append(None)
302+
elif isinstance(byte_range, SuffixByteRequest):
303+
starts.append(-byte_range.suffix)
304+
stops.append(None)
305+
else:
306+
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
288307
else:
289308
return []
290309
# TODO: expectations for exceptions or missing keys?
291-
res = await self.fs._cat_ranges(list(paths), starts, stops, on_error="return")
310+
res = await self.fs._cat_ranges(paths, starts, stops, on_error="return")
292311
# the following is an s3-specific condition we probably don't want to leak
293312
res = [b"" if (isinstance(r, OSError) and "not satisfiable" in str(r)) else r for r in res]
294313
for r in res:

0 commit comments

Comments
 (0)