Skip to content

Commit fd32152

Browse files
authored
Add type hints for primitive functions (#236)
* Add top-level types
* Add type hints to top-level utils; fix type errors in memory_repr
* Add type hints to LazyZarrArray
* Add type hints for primitive rechunk
* Add type hints for vendored Dask functions
* Add type hints for primitive blockwise
1 parent 25f5909 commit fd32152

File tree

8 files changed

+195
-102
lines changed

8 files changed

+195
-102
lines changed

cubed/primitive/blockwise.py

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import itertools
2+
import math
23
from dataclasses import dataclass
34
from functools import partial
4-
from typing import Callable, Dict
5+
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
56

67
import toolz
78
import zarr
89
from toolz import map
910

10-
from cubed.storage.zarr import lazy_empty
11+
from cubed.storage.zarr import T_ZarrArray, lazy_empty
12+
from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
1113
from cubed.utils import chunk_memory, get_item, to_chunksize
1214
from cubed.vendor.dask.array.core import normalize_chunks
1315
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
@@ -19,7 +21,7 @@
1921
sym_counter = 0
2022

2123

22-
def gensym(name):
24+
def gensym(name: str) -> str:
2325
global sym_counter
2426
sym_counter += 1
2527
return f"{name}-{sym_counter:03}"
@@ -43,19 +45,21 @@ class BlockwiseSpec:
4345
Write proxy with an ``array`` attribute that supports ``__setitem__``.
4446
"""
4547

46-
block_function: Callable
47-
function: Callable
48+
block_function: Callable[..., Any]
49+
function: Callable[..., Any]
4850
reads_map: Dict[str, CubedArrayProxy]
4951
write: CubedArrayProxy
5052

5153

52-
def apply_blockwise(out_key, *, config=BlockwiseSpec):
54+
def apply_blockwise(out_key: List[int], *, config: BlockwiseSpec) -> None:
5355
"""Stage function for blockwise."""
5456
# lithops needs params to be lists not tuples, so convert back
55-
out_key = tuple(out_key)
56-
out_chunk_key = key_to_slices(out_key, config.write.array, config.write.chunks)
57+
out_key_tuple = tuple(out_key)
58+
out_chunk_key = key_to_slices(
59+
out_key_tuple, config.write.array, config.write.chunks
60+
)
5761
args = []
58-
name_chunk_inds = config.block_function(("out",) + out_key)
62+
name_chunk_inds = config.block_function(("out",) + out_key_tuple)
5963
for name_chunk_ind in name_chunk_inds:
6064
name = name_chunk_ind[0]
6165
chunk_ind = name_chunk_ind[1:]
@@ -72,25 +76,27 @@ def apply_blockwise(out_key, *, config=BlockwiseSpec):
7276
config.write.open()[out_chunk_key] = result
7377

7478

75-
def key_to_slices(key, arr, chunks=None):
79+
def key_to_slices(
80+
key: Tuple[int, ...], arr: T_ZarrArray, chunks: Optional[T_Chunks] = None
81+
) -> Tuple[slice, ...]:
7682
"""Convert a chunk index key to a tuple of slices"""
7783
chunks = normalize_chunks(chunks or arr.chunks, shape=arr.shape, dtype=arr.dtype)
7884
return get_item(chunks, key)
7985

8086

8187
def blockwise(
82-
func,
83-
out_ind,
84-
*args,
85-
allowed_mem,
86-
reserved_mem,
87-
target_store,
88-
shape,
89-
dtype,
90-
chunks,
91-
new_axes=None,
92-
in_names=None,
93-
out_name=None,
88+
func: Callable[..., Any],
89+
out_ind: Sequence[Union[str, int]],
90+
*args: Any,
91+
allowed_mem: int,
92+
reserved_mem: int,
93+
target_store: T_Store,
94+
shape: T_Shape,
95+
dtype: T_DType,
96+
chunks: T_Chunks,
97+
new_axes: Optional[Dict[int, int]] = None,
98+
in_names: Optional[List[str]] = None,
99+
out_name: Optional[str] = None,
94100
**kwargs,
95101
):
96102
"""Apply a function across blocks from multiple source Zarr arrays.
@@ -126,20 +132,20 @@ def blockwise(
126132
"""
127133

128134
# Use dask's make_blockwise_graph
129-
arrays = args[::2]
135+
arrays: Sequence[T_ZarrArray] = args[::2]
130136
array_names = in_names or [f"in_{i}" for i in range(len(arrays))]
131137
array_map = {name: array for name, array in zip(array_names, arrays)}
132138

133-
inds = args[1::2]
139+
inds: Sequence[Union[str, int]] = args[1::2]
134140

135-
numblocks = {}
141+
numblocks: Dict[str, Tuple[int, ...]] = {}
136142
for name, array in zip(array_names, arrays):
137143
input_chunks = normalize_chunks(
138144
array.chunks, shape=array.shape, dtype=array.dtype
139145
)
140146
numblocks[name] = tuple(map(len, input_chunks))
141147

142-
argindsstr = []
148+
argindsstr: List[Any] = []
143149
for name, ind in zip(array_names, inds):
144150
argindsstr.extend((name, ind))
145151

@@ -228,21 +234,21 @@ def blockwise(
228234
# Code for fusing pipelines
229235

230236

231-
def is_fuse_candidate(pipeline):
237+
def is_fuse_candidate(pipeline: CubedPipeline) -> bool:
232238
"""
233239
Return True if a pipeline is a candidate for blockwise fusion.
234240
"""
235241
stages = pipeline.stages
236242
return len(stages) == 1 and stages[0].function == apply_blockwise
237243

238244

239-
def can_fuse_pipelines(pipeline1, pipeline2):
245+
def can_fuse_pipelines(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> bool:
240246
if is_fuse_candidate(pipeline1) and is_fuse_candidate(pipeline2):
241247
return pipeline1.num_tasks == pipeline2.num_tasks
242248
return False
243249

244250

245-
def fuse(pipeline1, pipeline2):
251+
def fuse(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> CubedPipeline:
246252
"""
247253
Fuse two blockwise pipelines into a single pipeline, avoiding writing to (or reading from) the target of the first pipeline.
248254
"""
@@ -282,8 +288,13 @@ def fused_func(*args):
282288

283289

284290
def make_blockwise_function(
285-
func, output, out_indices, *arrind_pairs, numblocks=None, new_axes=None
286-
):
291+
func: Callable[..., Any],
292+
output: str,
293+
out_indices: Sequence[Union[str, int]],
294+
*arrind_pairs: Any,
295+
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
296+
new_axes: Optional[Dict[int, int]] = None,
297+
) -> Callable[[List[int]], Any]:
287298
"""Make a function that is the equivalent of make_blockwise_graph."""
288299

289300
if numblocks is None:
@@ -335,8 +346,13 @@ def blockwise_fn(out_key):
335346

336347

337348
def make_blockwise_function_flattened(
338-
func, output, out_indices, *arrind_pairs, numblocks=None, new_axes=None
339-
):
349+
func: Callable[..., Any],
350+
output: str,
351+
out_indices: Sequence[Union[str, int]],
352+
*arrind_pairs: Any,
353+
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
354+
new_axes: Optional[Dict[int, int]] = None,
355+
) -> Callable[[List[int]], Any]:
340356
# TODO: make this a part of make_blockwise_function?
341357
blockwise_fn = make_blockwise_function(
342358
func, output, out_indices, *arrind_pairs, numblocks=numblocks, new_axes=new_axes
@@ -353,8 +369,13 @@ def blockwise_fn_flattened(out_key):
353369

354370

355371
def get_output_blocks(
356-
func, output, out_indices, *arrind_pairs, numblocks=None, new_axes=None
357-
):
372+
func: Callable[..., Any],
373+
output: str,
374+
out_indices: Sequence[Union[str, int]],
375+
*arrind_pairs: Any,
376+
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
377+
new_axes: Optional[Dict[int, int]] = None,
378+
) -> Iterator[List[int]]:
358379
if numblocks is None:
359380
raise ValueError("Missing required numblocks argument.")
360381
new_axes = new_axes or {}
@@ -369,24 +390,26 @@ def get_output_blocks(
369390

370391

371392
class IterableFromGenerator:
372-
def __init__(self, generator_fn):
393+
def __init__(self, generator_fn: Callable[[], Iterator[List[int]]]):
373394
self.generator_fn = generator_fn
374395

375396
def __iter__(self):
376397
return self.generator_fn()
377398

378399

379400
def num_output_blocks(
380-
func, output, out_indices, *arrind_pairs, numblocks=None, new_axes=None
381-
):
401+
func: Callable[..., Any],
402+
output: str,
403+
out_indices: Sequence[Union[str, int]],
404+
*arrind_pairs: Any,
405+
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
406+
new_axes: Optional[Dict[int, int]] = None,
407+
) -> int:
382408
if numblocks is None:
383409
raise ValueError("Missing required numblocks argument.")
384410
new_axes = new_axes or {}
385411
argpairs = list(toolz.partition(2, arrind_pairs))
386412

387413
# Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
388414
dims = _make_dims(argpairs, numblocks, new_axes)
389-
390-
import math
391-
392415
return math.prod(dims[i] for i in out_indices)

cubed/primitive/rechunk.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
from math import ceil, prod
2+
from typing import Any, Dict, List, Optional, Tuple
23

34
import zarr
45

5-
import cubed
6-
from cubed.primitive.types import CubedArrayProxy, CubedCopySpec
6+
from cubed.primitive.types import CubedArrayProxy, CubedCopySpec, CubedPipeline
77
from cubed.runtime.pipeline import spec_to_pipeline
8-
from cubed.storage.zarr import lazy_empty
8+
from cubed.storage.zarr import T_ZarrArray, lazy_empty
9+
from cubed.types import T_RegularChunks, T_Shape, T_Store
910
from cubed.vendor.rechunker.algorithm import rechunking_plan
1011
from cubed.vendor.rechunker.api import _validate_options
1112

1213

1314
def rechunk(
14-
source, target_chunks, allowed_mem, reserved_mem, target_store, temp_store=None
15-
):
15+
source: T_ZarrArray,
16+
target_chunks: T_RegularChunks,
17+
allowed_mem: int,
18+
reserved_mem: int,
19+
target_store: T_Store,
20+
temp_store: Optional[T_Store] = None,
21+
) -> CubedPipeline:
1622
"""Rechunk a Zarr array to have target_chunks.
1723
1824
Parameters
@@ -36,7 +42,7 @@ def rechunk(
3642

3743
# rechunker doesn't take account of uncompressed and compressed copies of the
3844
# input and output array chunk/selection, so adjust appropriately
39-
rechunker_max_mem = (allowed_mem - reserved_mem) / 4
45+
rechunker_max_mem = (allowed_mem - reserved_mem) // 4
4046

4147
copy_specs, intermediate, target = _setup_rechunk(
4248
source=source,
@@ -62,14 +68,14 @@ def rechunk(
6268

6369
# from rechunker, but simpler since it only has to handle Zarr arrays
6470
def _setup_rechunk(
65-
source,
66-
target_chunks,
67-
max_mem,
68-
target_store,
69-
target_options=None,
70-
temp_store=None,
71-
temp_options=None,
72-
):
71+
source: T_ZarrArray,
72+
target_chunks: T_RegularChunks,
73+
max_mem: int,
74+
target_store: T_Store,
75+
target_options: Optional[Dict[Any, Any]] = None,
76+
temp_store: Optional[T_Store] = None,
77+
temp_options: Optional[Dict[Any, Any]] = None,
78+
) -> Tuple[List[CubedCopySpec], T_ZarrArray, T_ZarrArray]:
7379
if temp_options is None:
7480
temp_options = target_options
7581
target_options = target_options or {}
@@ -90,14 +96,14 @@ def _setup_rechunk(
9096

9197

9298
def _setup_array_rechunk(
93-
source_array,
94-
target_chunks,
95-
max_mem,
96-
target_store_or_group,
97-
target_options=None,
98-
temp_store_or_group=None,
99-
temp_options=None,
100-
name=None,
99+
source_array: T_ZarrArray,
100+
target_chunks: T_RegularChunks,
101+
max_mem: int,
102+
target_store_or_group: T_Store,
103+
target_options: Optional[Dict[Any, Any]] = None,
104+
temp_store_or_group: Optional[T_Store] = None,
105+
temp_options: Optional[Dict[Any, Any]] = None,
106+
name: Optional[str] = None,
101107
) -> CubedCopySpec:
102108
_validate_options(target_options)
103109
_validate_options(temp_options)
@@ -115,9 +121,6 @@ def _setup_array_rechunk(
115121
# this is just a pass-through copy
116122
target_chunks = source_chunks
117123

118-
# TODO: rewrite to avoid the hard dependency on dask
119-
max_mem = cubed.vendor.dask.utils.parse_bytes(max_mem)
120-
121124
# don't consolidate reads for Dask arrays
122125
consolidate_reads = isinstance(source_array, zarr.core.Array)
123126
read_chunks, int_chunks, write_chunks = rechunking_plan(
@@ -143,11 +146,6 @@ def _setup_array_rechunk(
143146
**(target_options or {}),
144147
)
145148

146-
try:
147-
target_array.attrs.update(source_array.attrs)
148-
except AttributeError:
149-
pass
150-
151149
if read_chunks == write_chunks:
152150
int_array = None
153151
else:
@@ -172,6 +170,6 @@ def _setup_array_rechunk(
172170
return CubedCopySpec(read_proxy, int_proxy, write_proxy)
173171

174172

175-
def total_chunks(shape, chunks):
173+
def total_chunks(shape: T_Shape, chunks: T_RegularChunks) -> int:
176174
# cf rechunker's chunk_keys
177175
return prod(ceil(s / c) for s, c in zip(shape, chunks))

cubed/primitive/types.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from dataclasses import dataclass
2-
from typing import Any, Iterable, Optional
2+
from typing import Any, Optional, Sequence
33

4-
from cubed.storage.zarr import open_if_lazy_zarr_array
4+
import zarr
5+
6+
from cubed.storage.zarr import T_ZarrArray, open_if_lazy_zarr_array
7+
from cubed.types import T_RegularChunks
58
from cubed.vendor.rechunker.types import Config, Stage
69

710

811
@dataclass(frozen=True)
912
class CubedPipeline:
1013
"""Generalisation of rechunker ``Pipeline`` with extra attributes."""
1114

12-
stages: Iterable[Stage]
15+
stages: Sequence[Stage]
1316
config: Config
1417
target_array: Any
1518
intermediate_array: Optional[Any]
@@ -20,11 +23,11 @@ class CubedPipeline:
2023
class CubedArrayProxy:
2124
"""Generalisation of rechunker ``ArrayProxy`` with support for ``LazyZarrArray``."""
2225

23-
def __init__(self, array, chunks):
26+
def __init__(self, array: T_ZarrArray, chunks: T_RegularChunks):
2427
self.array = array
2528
self.chunks = chunks
2629

27-
def open(self):
30+
def open(self) -> zarr.Array:
2831
return open_if_lazy_zarr_array(self.array)
2932

3033

0 commit comments

Comments (0)