Skip to content

Commit 3fd2195

Browse files
authored
Rename block_function to key_function in blockwise (#426)
Rename `make_blockwise_function` to `make_blockwise_key_function` in blockwise Rename `name_chunk_ind` to `in_key`
1 parent e364761 commit 3fd2195

File tree

4 files changed

+77
-78
lines changed

4 files changed

+77
-78
lines changed

cubed/array_api/manipulation_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def reshape_chunks(x, shape, chunks):
250250
# use an empty template (handles smaller end chunks)
251251
template = empty(shape, dtype=x.dtype, chunks=chunks, spec=x.spec)
252252

253-
def block_function(out_key):
253+
def key_function(out_key):
254254
out_coords = out_key[1:]
255255
offset = block_id_to_offset(out_coords, template.numblocks)
256256
in_coords = offset_to_block_id(offset, x.numblocks)
@@ -261,7 +261,7 @@ def block_function(out_key):
261261

262262
return general_blockwise(
263263
_reshape_chunk,
264-
block_function,
264+
key_function,
265265
x,
266266
template,
267267
shape=shape,
@@ -291,7 +291,7 @@ def stack(arrays, /, *, axis=0):
291291

292292
array_names = [a.name for a in arrays]
293293

294-
def block_function(out_key):
294+
def key_function(out_key):
295295
out_coords = out_key[1:]
296296
in_name = array_names[out_coords[axis]]
297297
return ((in_name, *(out_coords[:axis] + out_coords[(axis + 1) :])),)
@@ -302,7 +302,7 @@ def block_function(out_key):
302302
# assume they are the same. See https://github.com/cubed-dev/cubed/issues/414
303303
return general_blockwise(
304304
_read_stack_chunk,
305-
block_function,
305+
key_function,
306306
*arrays,
307307
shape=shape,
308308
dtype=dtype,

cubed/core/ops.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def blockwise(
325325

326326
def general_blockwise(
327327
func,
328-
block_function,
328+
key_function,
329329
*arrays,
330330
shape,
331331
dtype,
@@ -354,7 +354,7 @@ def general_blockwise(
354354
target_store = new_temp_path(name=name, spec=spec)
355355
op = primitive_general_blockwise(
356356
func,
357-
block_function,
357+
key_function,
358358
*zargs,
359359
allowed_mem=spec.allowed_mem,
360360
reserved_mem=spec.reserved_mem,
@@ -833,7 +833,7 @@ def merge_chunks_new(x, chunks):
833833
i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1
834834
]
835835

836-
def block_function(out_key):
836+
def key_function(out_key):
837837
out_coords = out_key[1:]
838838

839839
in_keys = []
@@ -855,7 +855,7 @@ def block_function(out_key):
855855

856856
return general_blockwise(
857857
_concatenate2,
858-
block_function,
858+
key_function,
859859
x,
860860
shape=x.shape,
861861
dtype=x.dtype,
@@ -1109,7 +1109,7 @@ def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None):
11091109
shape = tuple(map(sum, chunks))
11101110
axis = tuple(ax for ax in split_every.keys())
11111111

1112-
def block_function(out_key):
1112+
def key_function(out_key):
11131113
out_coords = out_key[1:]
11141114

11151115
# return a tuple with a single item that is an iterator of input keys to be merged
@@ -1124,7 +1124,7 @@ def block_function(out_key):
11241124
]
11251125
return (iter([(x.name,) + tuple(p) for p in product(*in_keys)]),)
11261126

1127-
# Since block_function returns an iterator of input keys, the the array chunks passed to
1127+
# Since key_function returns an iterator of input keys, the the array chunks passed to
11281128
# _partial_reduce are retrieved one at a time. However, we need an extra chunk of memory
11291129
# to stay within limits (maybe because the iterator doesn't free the previous object
11301130
# before getting the next). We also need extra memory to hold two reduced chunks, since
@@ -1133,7 +1133,7 @@ def block_function(out_key):
11331133

11341134
return general_blockwise(
11351135
_partial_reduce,
1136-
block_function,
1136+
key_function,
11371137
x,
11381138
shape=shape,
11391139
dtype=dtype,

cubed/primitive/blockwise.py

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class BlockwiseSpec:
4444
4545
Attributes
4646
----------
47-
block_function : Callable
48-
A function that maps an output chunk index to one or more input chunk indexes.
47+
key_function : Callable
48+
A function that maps an output chunk key to one or more input chunk keys.
4949
function : Callable
5050
A function that maps input chunks to an output chunk.
5151
function_nargs: int
@@ -58,28 +58,29 @@ class BlockwiseSpec:
5858
Write proxy with an ``array`` attribute that supports ``__setitem__``.
5959
"""
6060

61-
block_function: Callable[..., Any]
61+
key_function: Callable[..., Any]
6262
function: Callable[..., Any]
6363
function_nargs: int
6464
num_input_blocks: Tuple[int, ...]
6565
reads_map: Dict[str, CubedArrayProxy]
6666
write: CubedArrayProxy
6767

6868

69-
def apply_blockwise(out_key: List[int], *, config: BlockwiseSpec) -> None:
69+
def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None:
7070
"""Stage function for blockwise."""
7171
# lithops needs params to be lists not tuples, so convert back
72-
out_key_tuple = tuple(out_key)
72+
out_coords_tuple = tuple(out_coords)
7373
out_chunk_key = key_to_slices(
74-
out_key_tuple, config.write.array, config.write.chunks
74+
out_coords_tuple, config.write.array, config.write.chunks
7575
)
7676

7777
# get array chunks for input keys, preserving any nested list structure
7878
args = []
7979
get_chunk_config = partial(get_chunk, config=config)
80-
name_chunk_inds = config.block_function(("out",) + out_key_tuple)
81-
for name_chunk_ind in name_chunk_inds:
82-
arg = map_nested(get_chunk_config, name_chunk_ind)
80+
out_key = ("out",) + out_coords_tuple # array name is ignored by key_function
81+
in_keys = config.key_function(out_key)
82+
for in_key in in_keys:
83+
arg = map_nested(get_chunk_config, in_key)
8384
args.append(arg)
8485

8586
result = config.function(*args)
@@ -100,13 +101,13 @@ def key_to_slices(
100101
return get_item(chunks, key)
101102

102103

103-
def get_chunk(name_chunk_ind, config):
104+
def get_chunk(in_key, config):
104105
"""Read a chunk from the named array"""
105-
name = name_chunk_ind[0]
106-
chunk_ind = name_chunk_ind[1:]
106+
name = in_key[0]
107+
in_coords = in_key[1:]
107108
arr = config.reads_map[name].open()
108-
chunk_key = key_to_slices(chunk_ind, arr)
109-
arg = arr[chunk_key]
109+
selection = key_to_slices(in_coords, arr)
110+
arg = arr[selection]
110111
arg = numpy_array_to_backend_array(arg)
111112
return arg
112113

@@ -188,7 +189,7 @@ def blockwise(
188189
for name, ind in zip(array_names, inds):
189190
argindsstr.extend((name, ind))
190191

191-
block_function = make_blockwise_function_flattened(
192+
key_function = make_blockwise_key_function_flattened(
192193
func,
193194
out_name or "out",
194195
out_ind,
@@ -199,7 +200,7 @@ def blockwise(
199200

200201
return general_blockwise(
201202
func,
202-
block_function,
203+
key_function,
203204
*arrays,
204205
allowed_mem=allowed_mem,
205206
reserved_mem=reserved_mem,
@@ -219,7 +220,7 @@ def blockwise(
219220

220221
def general_blockwise(
221222
func: Callable[..., Any],
222-
block_function: Callable[..., Any],
223+
key_function: Callable[..., Any],
223224
*arrays: Any,
224225
allowed_mem: int,
225226
reserved_mem: int,
@@ -242,8 +243,8 @@ def general_blockwise(
242243
----------
243244
func : callable
244245
Function to apply to individual tuples of blocks
245-
block_function : callable
246-
A function that maps an output chunk index to one or more input chunk indexes.
246+
key_function : callable
247+
A function that maps an output chunk key to one or more input chunk keys.
247248
*arrays : sequence of Array
248249
The input arrays.
249250
allowed_mem : int
@@ -291,7 +292,7 @@ def general_blockwise(
291292
}
292293
write_proxy = CubedArrayProxy(target_array, chunksize)
293294
spec = BlockwiseSpec(
294-
block_function,
295+
key_function,
295296
func_with_kwargs,
296297
len(arrays),
297298
num_input_blocks,
@@ -460,10 +461,8 @@ def fuse(
460461

461462
mappable = pipeline2.mappable
462463

463-
def fused_blockwise_func(out_key):
464-
return pipeline1.config.block_function(
465-
*pipeline2.config.block_function(out_key)
466-
)
464+
def fused_key_func(out_key):
465+
return pipeline1.config.key_function(*pipeline2.config.key_function(out_key))
467466

468467
def fused_func(*args):
469468
return pipeline2.config.function(pipeline1.config.function(*args))
@@ -476,7 +475,7 @@ def fused_func(*args):
476475
for n in pipeline1.config.num_input_blocks
477476
)
478477
spec = BlockwiseSpec(
479-
fused_blockwise_func,
478+
fused_key_func,
480479
fused_func,
481480
function_nargs,
482481
num_input_blocks,
@@ -530,36 +529,36 @@ def fuse_multiple(
530529

531530
mappable = pipeline.mappable
532531

533-
def apply_pipeline_block_func(pipeline, n_input_blocks, arg):
532+
def apply_pipeline_key_func(pipeline, n_input_blocks, arg):
534533
if pipeline is None:
535534
return (arg,)
536535
if n_input_blocks == 1:
537536
assert isinstance(arg, tuple)
538-
return pipeline.config.block_function(arg)
537+
return pipeline.config.key_function(arg)
539538
else:
540539
# more than one input block is being read from arg
541540
assert isinstance(arg, (list, Iterator))
542541
if isinstance(arg, list):
543542
return tuple(
544543
list(item)
545-
for item in zip(*(pipeline.config.block_function(a) for a in arg))
544+
for item in zip(*(pipeline.config.key_function(a) for a in arg))
546545
)
547546
else:
548547
# Return iterators to avoid materializing all array blocks at
549548
# once.
550549
return tuple(
551550
iter(list(item))
552-
for item in zip(*(pipeline.config.block_function(a) for a in arg))
551+
for item in zip(*(pipeline.config.key_function(a) for a in arg))
553552
)
554553

555-
def fused_blockwise_func(out_key):
554+
def fused_key_func(out_key):
556555
# this will change when multiple outputs are supported
557-
args = pipeline.config.block_function(out_key)
556+
args = pipeline.config.key_function(out_key)
558557
# split all args to the fused function into groups, one for each predecessor function
559558
func_args = tuple(
560559
item
561560
for i, (p, a) in enumerate(zip(predecessor_pipelines, args))
562-
for item in apply_pipeline_block_func(
561+
for item in apply_pipeline_key_func(
563562
p, pipeline.config.num_input_blocks[i], a
564563
)
565564
)
@@ -602,7 +601,7 @@ def fused_func(*args):
602601
read_proxies.update(p.config.reads_map)
603602
write_proxy = pipeline.config.write
604603
spec = BlockwiseSpec(
605-
fused_blockwise_func,
604+
fused_key_func,
606605
fused_func,
607606
fused_function_nargs,
608607
fused_num_input_blocks,
@@ -643,10 +642,10 @@ def fused_func(*args):
643642
)
644643

645644

646-
# blockwise functions
645+
# blockwise key functions
647646

648647

649-
def make_blockwise_function(
648+
def make_blockwise_key_function(
650649
func: Callable[..., Any],
651650
output: str,
652651
out_indices: Sequence[Union[str, int]],
@@ -675,7 +674,7 @@ def make_blockwise_function(
675674
False,
676675
)
677676

678-
def blockwise_fn(out_key):
677+
def key_function(out_key):
679678
out_coords = out_key[1:]
680679

681680
# from Dask make_blockwise_graph
@@ -701,27 +700,27 @@ def blockwise_fn(out_key):
701700

702701
return val
703702

704-
return blockwise_fn
703+
return key_function
705704

706705

707-
def make_blockwise_function_flattened(
706+
def make_blockwise_key_function_flattened(
708707
func: Callable[..., Any],
709708
output: str,
710709
out_indices: Sequence[Union[str, int]],
711710
*arrind_pairs: Any,
712711
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
713712
new_axes: Optional[Dict[int, int]] = None,
714713
) -> Callable[[List[int]], Any]:
715-
# TODO: make this a part of make_blockwise_function?
716-
blockwise_fn = make_blockwise_function(
714+
# TODO: make this a part of make_blockwise_key_function?
715+
key_function = make_blockwise_key_function(
717716
func, output, out_indices, *arrind_pairs, numblocks=numblocks, new_axes=new_axes
718717
)
719718

720719
def blockwise_fn_flattened(out_key):
721-
name_chunk_inds = blockwise_fn(out_key)[1:] # drop function in position 0
720+
in_keys = key_function(out_key)[1:] # drop function in position 0
722721
# flatten (nested) lists indicating contraction
723-
if isinstance(name_chunk_inds[0], list):
724-
name_chunk_inds = list(flatten(name_chunk_inds))
725-
return name_chunk_inds
722+
if isinstance(in_keys[0], list):
723+
in_keys = list(flatten(in_keys))
724+
return in_keys
726725

727726
return blockwise_fn_flattened

0 commit comments

Comments
 (0)