Skip to content

Commit

Permalink
Add type hints for primitive functions (#236)
Browse files Browse the repository at this point in the history
* Add top-level types

* Add type hints to top-level utils

Fix type errors in memory_repr

* Add type hints to LazyZarrArray

* Add type hints for primitive rechunk

* Add type hints for vendored Dask functions

* Add type hints for primitive blockwise
  • Loading branch information
tomwhite authored Jul 3, 2023
1 parent 25f5909 commit fd32152
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 102 deletions.
105 changes: 64 additions & 41 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import itertools
import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import toolz
import zarr
from toolz import map

from cubed.storage.zarr import lazy_empty
from cubed.storage.zarr import T_ZarrArray, lazy_empty
from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
from cubed.utils import chunk_memory, get_item, to_chunksize
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
Expand All @@ -19,7 +21,7 @@
sym_counter = 0


def gensym(name):
def gensym(name: str) -> str:
global sym_counter
sym_counter += 1
return f"{name}-{sym_counter:03}"
Expand All @@ -43,19 +45,21 @@ class BlockwiseSpec:
Write proxy with an ``array`` attribute that supports ``__setitem__``.
"""

block_function: Callable
function: Callable
block_function: Callable[..., Any]
function: Callable[..., Any]
reads_map: Dict[str, CubedArrayProxy]
write: CubedArrayProxy


def apply_blockwise(out_key, *, config=BlockwiseSpec):
def apply_blockwise(out_key: List[int], *, config: BlockwiseSpec) -> None:
"""Stage function for blockwise."""
# lithops needs params to be lists not tuples, so convert back
out_key = tuple(out_key)
out_chunk_key = key_to_slices(out_key, config.write.array, config.write.chunks)
out_key_tuple = tuple(out_key)
out_chunk_key = key_to_slices(
out_key_tuple, config.write.array, config.write.chunks
)
args = []
name_chunk_inds = config.block_function(("out",) + out_key)
name_chunk_inds = config.block_function(("out",) + out_key_tuple)
for name_chunk_ind in name_chunk_inds:
name = name_chunk_ind[0]
chunk_ind = name_chunk_ind[1:]
Expand All @@ -72,25 +76,27 @@ def apply_blockwise(out_key, *, config=BlockwiseSpec):
config.write.open()[out_chunk_key] = result


def key_to_slices(key, arr, chunks=None):
def key_to_slices(
key: Tuple[int, ...], arr: T_ZarrArray, chunks: Optional[T_Chunks] = None
) -> Tuple[slice, ...]:
"""Convert a chunk index key to a tuple of slices"""
chunks = normalize_chunks(chunks or arr.chunks, shape=arr.shape, dtype=arr.dtype)
return get_item(chunks, key)


def blockwise(
func,
out_ind,
*args,
allowed_mem,
reserved_mem,
target_store,
shape,
dtype,
chunks,
new_axes=None,
in_names=None,
out_name=None,
func: Callable[..., Any],
out_ind: Sequence[Union[str, int]],
*args: Any,
allowed_mem: int,
reserved_mem: int,
target_store: T_Store,
shape: T_Shape,
dtype: T_DType,
chunks: T_Chunks,
new_axes: Optional[Dict[int, int]] = None,
in_names: Optional[List[str]] = None,
out_name: Optional[str] = None,
**kwargs,
):
"""Apply a function across blocks from multiple source Zarr arrays.
Expand Down Expand Up @@ -126,20 +132,20 @@ def blockwise(
"""

# Use dask's make_blockwise_graph
arrays = args[::2]
arrays: Sequence[T_ZarrArray] = args[::2]
array_names = in_names or [f"in_{i}" for i in range(len(arrays))]
array_map = {name: array for name, array in zip(array_names, arrays)}

inds = args[1::2]
inds: Sequence[Union[str, int]] = args[1::2]

numblocks = {}
numblocks: Dict[str, Tuple[int, ...]] = {}
for name, array in zip(array_names, arrays):
input_chunks = normalize_chunks(
array.chunks, shape=array.shape, dtype=array.dtype
)
numblocks[name] = tuple(map(len, input_chunks))

argindsstr = []
argindsstr: List[Any] = []
for name, ind in zip(array_names, inds):
argindsstr.extend((name, ind))

Expand Down Expand Up @@ -228,21 +234,21 @@ def blockwise(
# Code for fusing pipelines


def is_fuse_candidate(pipeline):
def is_fuse_candidate(pipeline: CubedPipeline) -> bool:
"""
Return True if a pipeline is a candidate for blockwise fusion.
"""
stages = pipeline.stages
return len(stages) == 1 and stages[0].function == apply_blockwise


def can_fuse_pipelines(pipeline1, pipeline2):
def can_fuse_pipelines(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> bool:
if is_fuse_candidate(pipeline1) and is_fuse_candidate(pipeline2):
return pipeline1.num_tasks == pipeline2.num_tasks
return False


def fuse(pipeline1, pipeline2):
def fuse(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> CubedPipeline:
"""
Fuse two blockwise pipelines into a single pipeline, avoiding writing to (or reading from) the target of the first pipeline.
"""
Expand Down Expand Up @@ -282,8 +288,13 @@ def fused_func(*args):


def make_blockwise_function(
func, output, out_indices, *arrind_pairs, numblocks=None, new_axes=None
):
func: Callable[..., Any],
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
new_axes: Optional[Dict[int, int]] = None,
) -> Callable[[List[int]], Any]:
"""Make a function that is the equivalent of make_blockwise_graph."""

if numblocks is None:
Expand Down Expand Up @@ -335,8 +346,13 @@ def blockwise_fn(out_key):


def make_blockwise_function_flattened(
func, output, out_indices, *arrind_pairs, numblocks=None, new_axes=None
):
func: Callable[..., Any],
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
new_axes: Optional[Dict[int, int]] = None,
) -> Callable[[List[int]], Any]:
# TODO: make this a part of make_blockwise_function?
blockwise_fn = make_blockwise_function(
func, output, out_indices, *arrind_pairs, numblocks=numblocks, new_axes=new_axes
Expand All @@ -353,8 +369,13 @@ def blockwise_fn_flattened(out_key):


def get_output_blocks(
func, output, out_indices, *arrind_pairs, numblocks=None, new_axes=None
):
func: Callable[..., Any],
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
new_axes: Optional[Dict[int, int]] = None,
) -> Iterator[List[int]]:
if numblocks is None:
raise ValueError("Missing required numblocks argument.")
new_axes = new_axes or {}
Expand All @@ -369,24 +390,26 @@ def get_output_blocks(


class IterableFromGenerator:
def __init__(self, generator_fn):
def __init__(self, generator_fn: Callable[[], Iterator[List[int]]]):
self.generator_fn = generator_fn

def __iter__(self):
return self.generator_fn()


def num_output_blocks(
func, output, out_indices, *arrind_pairs, numblocks=None, new_axes=None
):
func: Callable[..., Any],
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
new_axes: Optional[Dict[int, int]] = None,
) -> int:
if numblocks is None:
raise ValueError("Missing required numblocks argument.")
new_axes = new_axes or {}
argpairs = list(toolz.partition(2, arrind_pairs))

# Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
dims = _make_dims(argpairs, numblocks, new_axes)

import math

return math.prod(dims[i] for i in out_indices)
60 changes: 29 additions & 31 deletions cubed/primitive/rechunk.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from math import ceil, prod
from typing import Any, Dict, List, Optional, Tuple

import zarr

import cubed
from cubed.primitive.types import CubedArrayProxy, CubedCopySpec
from cubed.primitive.types import CubedArrayProxy, CubedCopySpec, CubedPipeline
from cubed.runtime.pipeline import spec_to_pipeline
from cubed.storage.zarr import lazy_empty
from cubed.storage.zarr import T_ZarrArray, lazy_empty
from cubed.types import T_RegularChunks, T_Shape, T_Store
from cubed.vendor.rechunker.algorithm import rechunking_plan
from cubed.vendor.rechunker.api import _validate_options


def rechunk(
source, target_chunks, allowed_mem, reserved_mem, target_store, temp_store=None
):
source: T_ZarrArray,
target_chunks: T_RegularChunks,
allowed_mem: int,
reserved_mem: int,
target_store: T_Store,
temp_store: Optional[T_Store] = None,
) -> CubedPipeline:
"""Rechunk a Zarr array to have target_chunks.
Parameters
Expand All @@ -36,7 +42,7 @@ def rechunk(

# rechunker doesn't take account of uncompressed and compressed copies of the
# input and output array chunk/selection, so adjust appropriately
rechunker_max_mem = (allowed_mem - reserved_mem) / 4
rechunker_max_mem = (allowed_mem - reserved_mem) // 4

copy_specs, intermediate, target = _setup_rechunk(
source=source,
Expand All @@ -62,14 +68,14 @@ def rechunk(

# from rechunker, but simpler since it only has to handle Zarr arrays
def _setup_rechunk(
source,
target_chunks,
max_mem,
target_store,
target_options=None,
temp_store=None,
temp_options=None,
):
source: T_ZarrArray,
target_chunks: T_RegularChunks,
max_mem: int,
target_store: T_Store,
target_options: Optional[Dict[Any, Any]] = None,
temp_store: Optional[T_Store] = None,
temp_options: Optional[Dict[Any, Any]] = None,
) -> Tuple[List[CubedCopySpec], T_ZarrArray, T_ZarrArray]:
if temp_options is None:
temp_options = target_options
target_options = target_options or {}
Expand All @@ -90,14 +96,14 @@ def _setup_rechunk(


def _setup_array_rechunk(
source_array,
target_chunks,
max_mem,
target_store_or_group,
target_options=None,
temp_store_or_group=None,
temp_options=None,
name=None,
source_array: T_ZarrArray,
target_chunks: T_RegularChunks,
max_mem: int,
target_store_or_group: T_Store,
target_options: Optional[Dict[Any, Any]] = None,
temp_store_or_group: Optional[T_Store] = None,
temp_options: Optional[Dict[Any, Any]] = None,
name: Optional[str] = None,
) -> CubedCopySpec:
_validate_options(target_options)
_validate_options(temp_options)
Expand All @@ -115,9 +121,6 @@ def _setup_array_rechunk(
# this is just a pass-through copy
target_chunks = source_chunks

# TODO: rewrite to avoid the hard dependency on dask
max_mem = cubed.vendor.dask.utils.parse_bytes(max_mem)

# don't consolidate reads for Dask arrays
consolidate_reads = isinstance(source_array, zarr.core.Array)
read_chunks, int_chunks, write_chunks = rechunking_plan(
Expand All @@ -143,11 +146,6 @@ def _setup_array_rechunk(
**(target_options or {}),
)

try:
target_array.attrs.update(source_array.attrs)
except AttributeError:
pass

if read_chunks == write_chunks:
int_array = None
else:
Expand All @@ -172,6 +170,6 @@ def _setup_array_rechunk(
return CubedCopySpec(read_proxy, int_proxy, write_proxy)


def total_chunks(shape, chunks):
def total_chunks(shape: T_Shape, chunks: T_RegularChunks) -> int:
# cf rechunker's chunk_keys
return prod(ceil(s / c) for s, c in zip(shape, chunks))
13 changes: 8 additions & 5 deletions cubed/primitive/types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from dataclasses import dataclass
from typing import Any, Iterable, Optional
from typing import Any, Optional, Sequence

from cubed.storage.zarr import open_if_lazy_zarr_array
import zarr

from cubed.storage.zarr import T_ZarrArray, open_if_lazy_zarr_array
from cubed.types import T_RegularChunks
from cubed.vendor.rechunker.types import Config, Stage


@dataclass(frozen=True)
class CubedPipeline:
"""Generalisation of rechunker ``Pipeline`` with extra attributes."""

stages: Iterable[Stage]
stages: Sequence[Stage]
config: Config
target_array: Any
intermediate_array: Optional[Any]
Expand All @@ -20,11 +23,11 @@ class CubedPipeline:
class CubedArrayProxy:
"""Generalisation of rechunker ``ArrayProxy`` with support for ``LazyZarrArray``."""

def __init__(self, array, chunks):
def __init__(self, array: T_ZarrArray, chunks: T_RegularChunks):
self.array = array
self.chunks = chunks

def open(self):
def open(self) -> zarr.Array:
return open_if_lazy_zarr_array(self.array)


Expand Down
Loading

0 comments on commit fd32152

Please sign in to comment.