Skip to content

Commit 73b884b

Browse files
TomAugspurger, d-v-b, jhamman, pre-commit-ci[bot]
authored
Allow mode casting for Stores (#2249)
* Allow mode casting * fixup * fixup * fixup * fixup * match message * Update src/zarr/testing/store.py Co-authored-by: Davis Bennett <[email protected]> * fixup * fixup * fixup * fixup * pre-commit * log methods * style: pre-commit fixes --------- Co-authored-by: Davis Bennett <[email protected]> Co-authored-by: Joe Hamman <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2edc548 commit 73b884b

File tree

11 files changed

+176
-19
lines changed

11 files changed

+176
-19
lines changed

Diff for: src/zarr/abc/store.py

+25
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,31 @@ async def empty(self) -> bool: ...
8888
@abstractmethod
8989
async def clear(self) -> None: ...
9090

91+
@abstractmethod
92+
def with_mode(self, mode: AccessModeLiteral) -> Self:
93+
"""
94+
Return a new store of the same type pointing to the same location with a new mode.
95+
96+
The returned Store is not automatically opened. Call :meth:`Store.open` before
97+
using.
98+
99+
Parameters
100+
----------
101+
mode: AccessModeLiteral
102+
The new mode to use.
103+
104+
Returns
105+
-------
106+
store:
107+
A new store of the same type with the new mode.
108+
109+
Examples
110+
--------
111+
>>> writer = zarr.store.MemoryStore(mode="w")
112+
>>> reader = writer.with_mode("r")
113+
"""
114+
...
115+
91116
@property
92117
def mode(self) -> AccessMode:
93118
"""Access mode of the store."""

Diff for: src/zarr/store/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ async def make_store_path(
9292
assert AccessMode.from_literal(mode) == store_like.store.mode
9393
result = store_like
9494
elif isinstance(store_like, Store):
95-
if mode is not None:
96-
assert AccessMode.from_literal(mode) == store_like.mode
95+
if mode is not None and mode != store_like.mode.str:
96+
store_like = store_like.with_mode(mode)
9797
await store_like._ensure_open()
9898
result = StorePath(store_like)
9999
elif store_like is None:

Diff for: src/zarr/store/local.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import shutil
66
from pathlib import Path
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Self
88

99
from zarr.abc.store import ByteRangeRequest, Store
1010
from zarr.core.buffer import Buffer
@@ -110,6 +110,9 @@ async def empty(self) -> bool:
110110
else:
111111
return True
112112

113+
def with_mode(self, mode: AccessModeLiteral) -> Self:
114+
return type(self)(root=self.root, mode=mode)
115+
113116
def __str__(self) -> str:
114117
return f"file://{self.root}"
115118

Diff for: src/zarr/store/logging.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from collections import defaultdict
77
from contextlib import contextmanager
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Self
99

1010
from zarr.abc.store import AccessMode, ByteRangeRequest, Store
1111
from zarr.core.buffer import Buffer
@@ -14,6 +14,7 @@
1414
from collections.abc import AsyncGenerator, Generator, Iterable
1515

1616
from zarr.core.buffer import Buffer, BufferPrototype
17+
from zarr.core.common import AccessModeLiteral
1718

1819

1920
class LoggingStore(Store):
@@ -28,6 +29,8 @@ def __init__(
2829
) -> None:
2930
self._store = store
3031
self.counter = defaultdict(int)
32+
self.log_level = log_level
33+
self.log_handler = log_handler
3134

3235
self._configure_logger(log_level, log_handler)
3336

@@ -96,6 +99,14 @@ def _is_open(self) -> bool: # type: ignore[override]
9699
with self.log():
97100
return self._store._is_open
98101

102+
async def _open(self) -> None:
103+
with self.log():
104+
return await self._store._open()
105+
106+
async def _ensure_open(self) -> None:
107+
with self.log():
108+
return await self._store._ensure_open()
109+
99110
async def empty(self) -> bool:
100111
with self.log():
101112
return await self._store.empty()
@@ -167,3 +178,11 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
167178
with self.log():
168179
async for key in self._store.list_dir(prefix=prefix):
169180
yield key
181+
182+
def with_mode(self, mode: AccessModeLiteral) -> Self:
183+
with self.log():
184+
return type(self)(
185+
self._store.with_mode(mode),
186+
log_level=self.log_level,
187+
log_handler=self.log_handler,
188+
)

Diff for: src/zarr/store/memory.py

+41-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Self
44

55
from zarr.abc.store import ByteRangeRequest, Store
66
from zarr.core.buffer import Buffer, gpu
@@ -41,6 +41,9 @@ async def empty(self) -> bool:
4141
async def clear(self) -> None:
4242
self._store_dict.clear()
4343

44+
def with_mode(self, mode: AccessModeLiteral) -> Self:
45+
return type(self)(store_dict=self._store_dict, mode=mode)
46+
4447
def __str__(self) -> str:
4548
return f"memory://{id(self._store_dict)}"
4649

@@ -156,29 +159,58 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
156159

157160
class GpuMemoryStore(MemoryStore):
158161
"""A GPU only memory store that stores every chunk in GPU memory irrespective
159-
of the original location. This guarantees that chunks will always be in GPU
160-
memory for downstream processing. For location agnostic use cases, it would
161-
be better to use `MemoryStore` instead.
162+
of the original location.
163+
164+
The dictionary of buffers to initialize this memory store with *must* be
165+
GPU Buffers.
166+
167+
Writing data to this store through ``.set`` will move the buffer to the GPU
168+
if necessary.
169+
170+
Parameters
171+
----------
172+
store_dict: MutableMapping, optional
173+
A mutable mapping with string keys and :class:`zarr.core.buffer.gpu.Buffer`
174+
values.
162175
"""
163176

164-
_store_dict: MutableMapping[str, Buffer]
177+
_store_dict: MutableMapping[str, gpu.Buffer] # type: ignore[assignment]
165178

166179
def __init__(
167180
self,
168-
store_dict: MutableMapping[str, Buffer] | None = None,
181+
store_dict: MutableMapping[str, gpu.Buffer] | None = None,
169182
*,
170183
mode: AccessModeLiteral = "r",
171184
) -> None:
172-
super().__init__(mode=mode)
173-
if store_dict:
174-
self._store_dict = {k: gpu.Buffer.from_buffer(store_dict[k]) for k in iter(store_dict)}
185+
super().__init__(store_dict=store_dict, mode=mode) # type: ignore[arg-type]
175186

176187
def __str__(self) -> str:
177188
return f"gpumemory://{id(self._store_dict)}"
178189

179190
def __repr__(self) -> str:
180191
return f"GpuMemoryStore({str(self)!r})"
181192

193+
@classmethod
194+
def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self:
195+
"""
196+
Create a GpuMemoryStore from a dictionary of buffers at any location.
197+
198+
The dictionary backing the newly created ``GpuMemoryStore`` will not be
199+
the same as ``store_dict``.
200+
201+
Parameters
202+
----------
203+
store_dict: mapping
204+
A mapping of string keys to arbitrary Buffers. The buffer data
205+
will be moved into a :class:`gpu.Buffer`.
206+
207+
Returns
208+
-------
209+
GpuMemoryStore
210+
"""
211+
gpu_store_dict = {k: gpu.Buffer.from_buffer(v) for k, v in store_dict.items()}
212+
return cls(gpu_store_dict)
213+
182214
async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
183215
self._check_writable()
184216
assert isinstance(key, str)

Diff for: src/zarr/store/remote.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING, Any, Self
44

55
import fsspec
66

@@ -96,6 +96,14 @@ async def clear(self) -> None:
9696
async def empty(self) -> bool:
9797
return not await self.fs._find(self.path, withdirs=True)
9898

99+
def with_mode(self, mode: AccessModeLiteral) -> Self:
100+
return type(self)(
101+
fs=self.fs,
102+
mode=mode,
103+
path=self.path,
104+
allowed_exceptions=self.allowed_exceptions,
105+
)
106+
99107
def __repr__(self) -> str:
100108
return f"<RemoteStore({type(self.fs).__name__}, {self.path})>"
101109

Diff for: src/zarr/store/zip.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import zipfile
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Literal
8+
from typing import TYPE_CHECKING, Any, Literal, Self
99

1010
from zarr.abc.store import ByteRangeRequest, Store
1111
from zarr.core.buffer import Buffer, BufferPrototype
@@ -112,6 +112,9 @@ async def empty(self) -> bool:
112112
with self._lock:
113113
return not self._zf.namelist()
114114

115+
def with_mode(self, mode: ZipStoreAccessModeLiteral) -> Self: # type: ignore[override]
116+
raise NotImplementedError("ZipStore cannot be reopened with a new mode.")
117+
115118
def __str__(self) -> str:
116119
return f"zip://{self.path}"
117120

Diff for: src/zarr/testing/store.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import pickle
2-
from typing import Any, Generic, TypeVar
2+
from typing import Any, Generic, TypeVar, cast
33

44
import pytest
55

66
from zarr.abc.store import AccessMode, Store
77
from zarr.core.buffer import Buffer, default_buffer_prototype
8+
from zarr.core.common import AccessModeLiteral
89
from zarr.core.sync import _collect_aiterator, collect_aiterator
910
from zarr.store._utils import _normalize_interval_index
1011
from zarr.testing.utils import assert_bytes_equal
@@ -274,6 +275,41 @@ async def test_list_dir(self, store: S) -> None:
274275
keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
275276
assert sorted(keys_expected) == sorted(keys_observed)
276277

278+
async def test_with_mode(self, store: S) -> None:
279+
data = b"0000"
280+
self.set(store, "key", self.buffer_cls.from_bytes(data))
281+
assert self.get(store, "key").to_bytes() == data
282+
283+
for mode in ["r", "a"]:
284+
mode = cast(AccessModeLiteral, mode)
285+
clone = store.with_mode(mode)
286+
# await store.close()
287+
await clone._ensure_open()
288+
assert clone.mode == AccessMode.from_literal(mode)
289+
assert isinstance(clone, type(store))
290+
291+
# earlier writes are visible
292+
result = await clone.get("key", default_buffer_prototype())
293+
assert result is not None
294+
assert result.to_bytes() == data
295+
296+
# writes to original after with_mode are visible
297+
self.set(store, "key-2", self.buffer_cls.from_bytes(data))
298+
result = await clone.get("key-2", default_buffer_prototype())
299+
assert result is not None
300+
assert result.to_bytes() == data
301+
302+
if mode == "a":
303+
# writes to clone are visible in the original
304+
await clone.set("key-3", self.buffer_cls.from_bytes(data))
305+
result = await clone.get("key-3", default_buffer_prototype())
306+
assert result is not None
307+
assert result.to_bytes() == data
308+
309+
else:
310+
with pytest.raises(ValueError, match="store mode"):
311+
await clone.set("key-3", self.buffer_cls.from_bytes(data))
312+
277313
async def test_set_if_not_exists(self, store: S) -> None:
278314
key = "k"
279315
data_buf = self.buffer_cls.from_bytes(b"0000")

Diff for: tests/v3/test_store/test_logging.py

+8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
import zarr
8+
import zarr.store
89
from zarr.core.buffer import default_buffer_prototype
910
from zarr.store.logging import LoggingStore
1011

@@ -48,3 +49,10 @@ async def test_logging_store_counter(store: Store) -> None:
4849
assert wrapped.counter["list"] == 0
4950
assert wrapped.counter["list_dir"] == 0
5051
assert wrapped.counter["list_prefix"] == 0
52+
53+
54+
async def test_with_mode():
55+
wrapped = LoggingStore(store=zarr.store.MemoryStore(mode="w"), log_level="INFO")
56+
new = wrapped.with_mode(mode="r")
57+
assert new.mode.str == "r"
58+
assert new.log_level == "INFO"

Diff for: tests/v3/test_store/test_memory.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,14 @@ def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
5858
def get(self, store: MemoryStore, key: str) -> Buffer:
5959
return store._store_dict[key]
6060

61-
@pytest.fixture(params=[None, {}])
62-
def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]:
63-
return {"store_dict": request.param, "mode": "r+"}
61+
@pytest.fixture(params=[None, True])
62+
def store_kwargs(
63+
self, request: pytest.FixtureRequest
64+
) -> dict[str, str | None | dict[str, Buffer]]:
65+
kwargs = {"store_dict": None, "mode": "r+"}
66+
if request.param is True:
67+
kwargs["store_dict"] = {}
68+
return kwargs
6469

6570
@pytest.fixture
6671
def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore:
@@ -80,3 +85,17 @@ def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None:
8085

8186
def test_list_prefix(self, store: GpuMemoryStore) -> None:
8287
assert True
88+
89+
def test_dict_reference(self, store: GpuMemoryStore) -> None:
90+
store_dict = {}
91+
result = GpuMemoryStore(store_dict=store_dict)
92+
assert result._store_dict is store_dict
93+
94+
def test_from_dict(self):
95+
d = {
96+
"a": gpu.Buffer.from_bytes(b"aaaa"),
97+
"b": cpu.Buffer.from_bytes(b"bbbb"),
98+
}
99+
result = GpuMemoryStore.from_dict(d)
100+
for v in result._store_dict.values():
101+
assert type(v) is gpu.Buffer

Diff for: tests/v3/test_store/test_zip.py

+4
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,7 @@ def test_api_integration(self, store: ZipStore) -> None:
9696
del root["bar"]
9797

9898
store.close()
99+
100+
async def test_with_mode(self, store: ZipStore) -> None:
101+
with pytest.raises(NotImplementedError, match="new mode"):
102+
await super().test_with_mode(store)

0 commit comments

Comments
 (0)