|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from collections.abc import Sequence |
3 | 4 | from pathlib import Path |
4 | 5 | from typing import Any, Literal |
5 | 6 |
|
|
38 | 39 | ) |
39 | 40 |
|
40 | 41 |
|
| 42 | +def _is_flat_int_sequence(value: object) -> bool: |
| 43 | + if isinstance(value, str | bytes): |
| 44 | + return False |
| 45 | + if not isinstance(value, Sequence): |
| 46 | + return False |
| 47 | + return all(isinstance(v, int) for v in value) |
| 48 | + |
| 49 | + |
| 50 | +def _is_dask_chunk_grid(value: object) -> bool: |
| 51 | + if isinstance(value, str | bytes): |
| 52 | + return False |
| 53 | + if not isinstance(value, Sequence): |
| 54 | + return False |
| 55 | + return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value) |
| 56 | + |
| 57 | + |
| 58 | +def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool: |
| 59 | + # Match Dask's private _check_regular_chunks() logic without depending on its internal API. |
| 60 | + for axis_chunks in chunk_grid: |
| 61 | + if len(axis_chunks) <= 1: |
| 62 | + continue |
| 63 | + if len(set(axis_chunks[:-1])) > 1: |
| 64 | + return False |
| 65 | + if axis_chunks[-1] > axis_chunks[0]: |
| 66 | + return False |
| 67 | + return True |
| 68 | + |
| 69 | + |
| 70 | +def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None: |
| 71 | + if isinstance(chunks, int): |
| 72 | + return chunks |
| 73 | + if _is_flat_int_sequence(chunks): |
| 74 | + return tuple(chunks) |
| 75 | + if _is_dask_chunk_grid(chunks): |
| 76 | + chunk_grid = tuple(tuple(axis_chunks) for axis_chunks in chunks) |
| 77 | + if _is_regular_dask_chunk_grid(chunk_grid): |
| 78 | + return tuple(axis_chunks[0] for axis_chunks in chunk_grid) |
| 79 | + return None |
| 80 | + return None |
| 81 | + |
| 82 | + |
def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int:
    """Convert an explicit ``chunks`` entry to a Zarr chunk spec, or raise.

    Raises
    ------
    ValueError
        If *chunks* is neither a Zarr chunk shape nor a regular Dask chunk
        grid (an irregular grid cannot be expressed as one chunk shape).
    """
    zarr_chunks = _chunks_to_zarr_chunks(chunks)
    if zarr_chunks is not None:
        return zarr_chunks
    raise ValueError(
        "storage_options['chunks'] must be a Zarr chunk shape or a regular Dask chunk grid. "
        "Irregular Dask chunk grids must be rechunked before writing or omitted."
    )
| 91 | + |
| 92 | + |
| 93 | +def _prepare_single_scale_storage_options( |
| 94 | + storage_options: JSONDict | list[JSONDict] | None, |
| 95 | +) -> JSONDict | list[JSONDict] | None: |
| 96 | + if storage_options is None: |
| 97 | + return None |
| 98 | + if isinstance(storage_options, dict): |
| 99 | + prepared = dict(storage_options) |
| 100 | + if "chunks" in prepared: |
| 101 | + prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) |
| 102 | + return prepared |
| 103 | + return [dict(options) for options in storage_options] |
| 104 | + |
| 105 | + |
| 106 | +def _prepare_multiscale_storage_options( |
| 107 | + storage_options: JSONDict | list[JSONDict] | None, |
| 108 | +) -> JSONDict | list[JSONDict] | None: |
| 109 | + if storage_options is None: |
| 110 | + return None |
| 111 | + if isinstance(storage_options, dict): |
| 112 | + prepared = dict(storage_options) |
| 113 | + if "chunks" in prepared: |
| 114 | + prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) |
| 115 | + return prepared |
| 116 | + |
| 117 | + prepared_options = [dict(options) for options in storage_options] |
| 118 | + for options in prepared_options: |
| 119 | + if "chunks" in options: |
| 120 | + options["chunks"] = _normalize_explicit_chunks(options["chunks"]) |
| 121 | + return prepared_options |
| 122 | + |
| 123 | + |
41 | 124 | def _read_multiscale( |
42 | 125 | store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format |
43 | 126 | ) -> DataArray | DataTree: |
@@ -251,13 +334,8 @@ def _write_raster_dataarray( |
251 | 334 | if transformations is None: |
252 | 335 | raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") |
253 | 336 | input_axes: tuple[str, ...] = tuple(raster_data.dims) |
254 | | - chunks = raster_data.chunks |
255 | 337 | parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) |
256 | | - if storage_options is not None: |
257 | | - if "chunks" not in storage_options and isinstance(storage_options, dict): |
258 | | - storage_options["chunks"] = chunks |
259 | | - else: |
260 | | - storage_options = {"chunks": chunks} |
| 338 | + storage_options = _prepare_single_scale_storage_options(storage_options) |
261 | 339 | # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. |
262 | 340 | # We need this because the argument of write_image_ngff is called image while the argument of |
263 | 341 | # write_labels_ngff is called label. |
@@ -322,10 +400,9 @@ def _write_raster_datatree( |
322 | 400 | transformations = _get_transformations_xarray(xdata) |
323 | 401 | if transformations is None: |
324 | 402 | raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") |
325 | | - chunks = get_pyramid_levels(raster_data, "chunks") |
326 | 403 |
|
327 | 404 | parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) |
328 | | - storage_options = [{"chunks": chunk} for chunk in chunks] |
| 405 | + storage_options = _prepare_multiscale_storage_options(storage_options) |
329 | 406 | ome_zarr_format = get_ome_zarr_format(raster_format) |
330 | 407 | dask_delayed = write_multi_scale_ngff( |
331 | 408 | pyramid=data, |
|
0 commit comments