1
1
import itertools
2
+ import math
2
3
from dataclasses import dataclass
3
4
from functools import partial
4
- from typing import Callable , Dict
5
+ from typing import Any , Callable , Dict , Iterator , List , Optional , Sequence , Tuple , Union
5
6
6
7
import toolz
7
8
import zarr
8
9
from toolz import map
9
10
10
- from cubed .storage .zarr import lazy_empty
11
+ from cubed .storage .zarr import T_ZarrArray , lazy_empty
12
+ from cubed .types import T_Chunks , T_DType , T_Shape , T_Store
11
13
from cubed .utils import chunk_memory , get_item , to_chunksize
12
14
from cubed .vendor .dask .array .core import normalize_chunks
13
15
from cubed .vendor .dask .blockwise import _get_coord_mapping , _make_dims , lol_product
19
21
# Module-wide counter backing gensym(); incremented once per call.
sym_counter = 0


def gensym(name: str) -> str:
    """Return ``name`` suffixed with a fresh sequence number.

    Every call increments the module-level ``sym_counter`` and appends it to
    ``name`` after a hyphen, zero-padded to at least three digits
    (for example ``"array-001"``).
    """
    global sym_counter
    sym_counter += 1
    return f"{name}-{sym_counter:03}"
@@ -43,19 +45,21 @@ class BlockwiseSpec:
43
45
Write proxy with an ``array`` attribute that supports ``__setitem__``.
44
46
"""
45
47
46
- block_function : Callable
47
- function : Callable
48
+ block_function : Callable [..., Any ]
49
+ function : Callable [..., Any ]
48
50
reads_map : Dict [str , CubedArrayProxy ]
49
51
write : CubedArrayProxy
50
52
51
53
52
- def apply_blockwise (out_key , * , config = BlockwiseSpec ):
54
+ def apply_blockwise (out_key : List [ int ] , * , config : BlockwiseSpec ) -> None :
53
55
"""Stage function for blockwise."""
54
56
# lithops needs params to be lists not tuples, so convert back
55
- out_key = tuple (out_key )
56
- out_chunk_key = key_to_slices (out_key , config .write .array , config .write .chunks )
57
+ out_key_tuple = tuple (out_key )
58
+ out_chunk_key = key_to_slices (
59
+ out_key_tuple , config .write .array , config .write .chunks
60
+ )
57
61
args = []
58
- name_chunk_inds = config .block_function (("out" ,) + out_key )
62
+ name_chunk_inds = config .block_function (("out" ,) + out_key_tuple )
59
63
for name_chunk_ind in name_chunk_inds :
60
64
name = name_chunk_ind [0 ]
61
65
chunk_ind = name_chunk_ind [1 :]
@@ -72,25 +76,27 @@ def apply_blockwise(out_key, *, config=BlockwiseSpec):
72
76
config .write .open ()[out_chunk_key ] = result
73
77
74
78
75
def key_to_slices(
    key: Tuple[int, ...], arr: T_ZarrArray, chunks: Optional[T_Chunks] = None
) -> Tuple[slice, ...]:
    """Convert a chunk index key to a tuple of slices.

    ``chunks`` falls back to ``arr.chunks`` when it is omitted (or falsy).
    The chunks are normalized against ``arr.shape`` and ``arr.dtype`` and the
    entry for ``key`` is then looked up with ``get_item``.
    """
    normalized = normalize_chunks(
        chunks or arr.chunks, shape=arr.shape, dtype=arr.dtype
    )
    return get_item(normalized, key)
79
85
80
86
81
87
def blockwise (
82
- func ,
83
- out_ind ,
84
- * args ,
85
- allowed_mem ,
86
- reserved_mem ,
87
- target_store ,
88
- shape ,
89
- dtype ,
90
- chunks ,
91
- new_axes = None ,
92
- in_names = None ,
93
- out_name = None ,
88
+ func : Callable [..., Any ] ,
89
+ out_ind : Sequence [ Union [ str , int ]] ,
90
+ * args : Any ,
91
+ allowed_mem : int ,
92
+ reserved_mem : int ,
93
+ target_store : T_Store ,
94
+ shape : T_Shape ,
95
+ dtype : T_DType ,
96
+ chunks : T_Chunks ,
97
+ new_axes : Optional [ Dict [ int , int ]] = None ,
98
+ in_names : Optional [ List [ str ]] = None ,
99
+ out_name : Optional [ str ] = None ,
94
100
** kwargs ,
95
101
):
96
102
"""Apply a function across blocks from multiple source Zarr arrays.
@@ -126,20 +132,20 @@ def blockwise(
126
132
"""
127
133
128
134
# Use dask's make_blockwise_graph
129
- arrays = args [::2 ]
135
+ arrays : Sequence [ T_ZarrArray ] = args [::2 ]
130
136
array_names = in_names or [f"in_{ i } " for i in range (len (arrays ))]
131
137
array_map = {name : array for name , array in zip (array_names , arrays )}
132
138
133
- inds = args [1 ::2 ]
139
+ inds : Sequence [ Union [ str , int ]] = args [1 ::2 ]
134
140
135
- numblocks = {}
141
+ numblocks : Dict [ str , Tuple [ int , ...]] = {}
136
142
for name , array in zip (array_names , arrays ):
137
143
input_chunks = normalize_chunks (
138
144
array .chunks , shape = array .shape , dtype = array .dtype
139
145
)
140
146
numblocks [name ] = tuple (map (len , input_chunks ))
141
147
142
- argindsstr = []
148
+ argindsstr : List [ Any ] = []
143
149
for name , ind in zip (array_names , inds ):
144
150
argindsstr .extend ((name , ind ))
145
151
@@ -228,21 +234,21 @@ def blockwise(
228
234
# Code for fusing pipelines
229
235
230
236
231
def is_fuse_candidate(pipeline: CubedPipeline) -> bool:
    """
    Return True if a pipeline is a candidate for blockwise fusion.

    A candidate has exactly one stage, and that stage's function is
    ``apply_blockwise``.
    """
    stages = pipeline.stages
    return len(stages) == 1 and stages[0].function == apply_blockwise
237
243
238
244
239
def can_fuse_pipelines(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> bool:
    """
    Return True if the two pipelines can be fused.

    Both pipelines must be fuse candidates (see ``is_fuse_candidate``) and
    must have the same ``num_tasks``.
    """
    if not (is_fuse_candidate(pipeline1) and is_fuse_candidate(pipeline2)):
        return False
    return pipeline1.num_tasks == pipeline2.num_tasks
243
249
244
250
245
- def fuse (pipeline1 , pipeline2 ) :
251
+ def fuse (pipeline1 : CubedPipeline , pipeline2 : CubedPipeline ) -> CubedPipeline :
246
252
"""
247
253
Fuse two blockwise pipelines into a single pipeline, avoiding writing to (or reading from) the target of the first pipeline.
248
254
"""
@@ -282,8 +288,13 @@ def fused_func(*args):
282
288
283
289
284
290
def make_blockwise_function (
285
- func , output , out_indices , * arrind_pairs , numblocks = None , new_axes = None
286
- ):
291
+ func : Callable [..., Any ],
292
+ output : str ,
293
+ out_indices : Sequence [Union [str , int ]],
294
+ * arrind_pairs : Any ,
295
+ numblocks : Optional [Dict [str , Tuple [int , ...]]] = None ,
296
+ new_axes : Optional [Dict [int , int ]] = None ,
297
+ ) -> Callable [[List [int ]], Any ]:
287
298
"""Make a function that is the equivalent of make_blockwise_graph."""
288
299
289
300
if numblocks is None :
@@ -335,8 +346,13 @@ def blockwise_fn(out_key):
335
346
336
347
337
348
def make_blockwise_function_flattened (
338
- func , output , out_indices , * arrind_pairs , numblocks = None , new_axes = None
339
- ):
349
+ func : Callable [..., Any ],
350
+ output : str ,
351
+ out_indices : Sequence [Union [str , int ]],
352
+ * arrind_pairs : Any ,
353
+ numblocks : Optional [Dict [str , Tuple [int , ...]]] = None ,
354
+ new_axes : Optional [Dict [int , int ]] = None ,
355
+ ) -> Callable [[List [int ]], Any ]:
340
356
# TODO: make this a part of make_blockwise_function?
341
357
blockwise_fn = make_blockwise_function (
342
358
func , output , out_indices , * arrind_pairs , numblocks = numblocks , new_axes = new_axes
@@ -353,8 +369,13 @@ def blockwise_fn_flattened(out_key):
353
369
354
370
355
371
def get_output_blocks (
356
- func , output , out_indices , * arrind_pairs , numblocks = None , new_axes = None
357
- ):
372
+ func : Callable [..., Any ],
373
+ output : str ,
374
+ out_indices : Sequence [Union [str , int ]],
375
+ * arrind_pairs : Any ,
376
+ numblocks : Optional [Dict [str , Tuple [int , ...]]] = None ,
377
+ new_axes : Optional [Dict [int , int ]] = None ,
378
+ ) -> Iterator [List [int ]]:
358
379
if numblocks is None :
359
380
raise ValueError ("Missing required numblocks argument." )
360
381
new_axes = new_axes or {}
@@ -369,24 +390,26 @@ def get_output_blocks(
369
390
370
391
371
392
class IterableFromGenerator:
    """Iterable wrapper around a zero-argument generator function.

    Each call to ``iter()`` on an instance calls ``generator_fn`` again and
    returns whatever iterator it produces.
    """

    def __init__(self, generator_fn: Callable[[], Iterator[List[int]]]) -> None:
        self.generator_fn = generator_fn

    def __iter__(self) -> Iterator[List[int]]:
        return self.generator_fn()
377
398
378
399
379
400
def num_output_blocks(
    func: Callable[..., Any],
    output: str,
    out_indices: Sequence[Union[str, int]],
    *arrind_pairs: Any,
    numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
    new_axes: Optional[Dict[int, int]] = None,
) -> int:
    """Return the product of the per-dimension sizes over ``out_indices``.

    ``arrind_pairs`` is grouped into consecutive pairs and passed, together
    with ``numblocks`` and ``new_axes``, to ``_make_dims``; the sizes it
    reports for the indices in ``out_indices`` are multiplied together.

    ``func`` and ``output`` are not used in this function body.

    Raises
    ------
    ValueError
        If ``numblocks`` is None.
    """
    if numblocks is None:
        raise ValueError("Missing required numblocks argument.")
    new_axes = new_axes or {}
    argpairs = list(toolz.partition(2, arrind_pairs))

    # Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
    dims = _make_dims(argpairs, numblocks, new_axes)
    return math.prod(dims[i] for i in out_indices)
0 commit comments