Skip to content

Commit 5fae32c

Browse files
authored
Pass kwargs to tensorstore (#332)
* pass kwargs to tensorstore * style
1 parent a165513 commit 5fae32c

File tree

3 files changed

+31
-15
lines changed

3 files changed

+31
-15
lines changed

iohub/ngff/nodes.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,17 @@ def dask_array(self):
383383
def downscale(self):
384384
raise NotImplementedError
385385

386-
def tensorstore(self, concurrency: int | None = None):
387-
"""Open the zarr array as a TensorStore object.
386+
def tensorstore(self, **kwargs):
387+
"""
388+
Open the zarr array as a TensorStore object.
388389
Needs the optional dependency ``tensorstore``.
389390
391+
Parameters
392+
----------
393+
**kwargs : dict, optional
394+
Additional keyword arguments to pass to ``tensorstore.open()``,
395+
by default None
396+
390397
Returns
391398
-------
392399
TensorStore
@@ -401,15 +408,10 @@ def tensorstore(self, concurrency: int | None = None):
401408
"path": str(Path(self.store.root) / self.path.strip("/")),
402409
},
403410
}
411+
if "read" in kwargs or "write" in kwargs:
412+
raise ValueError("Cannot override file mode for the Zarr store.")
404413
zarr_dataset = ts.open(
405-
ts_spec,
406-
read=True,
407-
write=not self.read_only,
408-
context=(
409-
ts.Context({"data_copy_concurrency": {"limit": concurrency}})
410-
if concurrency
411-
else None
412-
),
414+
ts_spec, read=True, write=not self.read_only, **kwargs
413415
).result()
414416
return zarr_dataset
415417

iohub/ngff/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,16 @@ def _save_transformed(
196196
output_time_indices: int | list[int],
197197
) -> None:
198198
# NOTE: use tensorstore due to zarr-python#3221
199+
import tensorstore
200+
199201
with open_ome_zarr(
200202
output_position_path, layout="fov", mode="r+"
201203
) as output_dataset:
202-
ts = output_dataset.data.tensorstore(concurrency=4)
204+
ts = output_dataset.data.tensorstore(
205+
context=tensorstore.Context(
206+
{"data_copy_concurrency": {"limit": 4}}
207+
)
208+
)
203209
ts.oindex[output_time_indices, output_channel_indices].write(
204210
transformed
205211
).result()

tests/ngff/test_ngff.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -475,11 +475,19 @@ def test_ome_zarr_to_tensorstore(
475475
channels_and_random_5d, arr_name, version, concurrency
476476
):
477477
"""Test `iohub.ngff.Position.data` to tensorstore"""
478+
import tensorstore as ts
479+
478480
channel_names, random_5d = channels_and_random_5d
479481
with _temp_ome_zarr(
480482
random_5d, channel_names, arr_name, version=version
481483
) as dataset:
482-
tstore = dataset[arr_name].tensorstore(concurrency=concurrency)
484+
tstore = dataset[arr_name].tensorstore(
485+
context=(
486+
ts.Context({"data_copy_concurrency": {"limit": concurrency}})
487+
if concurrency is not None
488+
else None
489+
)
490+
)
483491
assert_array_equal(tstore, random_5d)
484492
zeros = np.zeros_like(random_5d)
485493
tstore[...].write(zeros).result()
@@ -1288,9 +1296,9 @@ def test_hcs_external_reader(tmp_path):
12881296
fov.create_zeros("0", shape=(1, 2, 3, y_size, x_size), dtype=int)
12891297
n_rows = len(dataset.metadata.rows)
12901298
n_cols = len(dataset.metadata.columns)
1291-
plate = list(
1292-
ome_zarr.reader.Reader(ome_zarr.io.parse_url(store_path))()
1293-
)[0]
1299+
plate = list(ome_zarr.reader.Reader(ome_zarr.io.parse_url(store_path))())[
1300+
0
1301+
]
12941302
assert plate.data[0].shape == (1, 2, 3, y_size * n_rows, x_size * n_cols)
12951303
assert plate.data[0].dtype == int
12961304
assert not plate.data[0].any()

0 commit comments

Comments
 (0)