Skip to content

Commit

Permalink
Rename block_function to key_function in blockwise (#426)
Browse files Browse the repository at this point in the history
Rename `make_blockwise_function` to `make_blockwise_key_function` in blockwise

Rename `name_chunk_ind` to `in_key`
  • Loading branch information
tomwhite authored Mar 15, 2024
1 parent e364761 commit 3fd2195
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 78 deletions.
8 changes: 4 additions & 4 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def reshape_chunks(x, shape, chunks):
# use an empty template (handles smaller end chunks)
template = empty(shape, dtype=x.dtype, chunks=chunks, spec=x.spec)

def block_function(out_key):
def key_function(out_key):
out_coords = out_key[1:]
offset = block_id_to_offset(out_coords, template.numblocks)
in_coords = offset_to_block_id(offset, x.numblocks)
Expand All @@ -261,7 +261,7 @@ def block_function(out_key):

return general_blockwise(
_reshape_chunk,
block_function,
key_function,
x,
template,
shape=shape,
Expand Down Expand Up @@ -291,7 +291,7 @@ def stack(arrays, /, *, axis=0):

array_names = [a.name for a in arrays]

def block_function(out_key):
def key_function(out_key):
out_coords = out_key[1:]
in_name = array_names[out_coords[axis]]
return ((in_name, *(out_coords[:axis] + out_coords[(axis + 1) :])),)
Expand All @@ -302,7 +302,7 @@ def block_function(out_key):
# assume they are the same. See https://github.com/cubed-dev/cubed/issues/414
return general_blockwise(
_read_stack_chunk,
block_function,
key_function,
*arrays,
shape=shape,
dtype=dtype,
Expand Down
14 changes: 7 additions & 7 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def blockwise(

def general_blockwise(
func,
block_function,
key_function,
*arrays,
shape,
dtype,
Expand Down Expand Up @@ -354,7 +354,7 @@ def general_blockwise(
target_store = new_temp_path(name=name, spec=spec)
op = primitive_general_blockwise(
func,
block_function,
key_function,
*zargs,
allowed_mem=spec.allowed_mem,
reserved_mem=spec.reserved_mem,
Expand Down Expand Up @@ -833,7 +833,7 @@ def merge_chunks_new(x, chunks):
i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1
]

def block_function(out_key):
def key_function(out_key):
out_coords = out_key[1:]

in_keys = []
Expand All @@ -855,7 +855,7 @@ def block_function(out_key):

return general_blockwise(
_concatenate2,
block_function,
key_function,
x,
shape=x.shape,
dtype=x.dtype,
Expand Down Expand Up @@ -1109,7 +1109,7 @@ def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None):
shape = tuple(map(sum, chunks))
axis = tuple(ax for ax in split_every.keys())

def block_function(out_key):
def key_function(out_key):
out_coords = out_key[1:]

# return a tuple with a single item that is an iterator of input keys to be merged
Expand All @@ -1124,7 +1124,7 @@ def block_function(out_key):
]
return (iter([(x.name,) + tuple(p) for p in product(*in_keys)]),)

# Since block_function returns an iterator of input keys, the array chunks passed to
# Since key_function returns an iterator of input keys, the array chunks passed to
# _partial_reduce are retrieved one at a time. However, we need an extra chunk of memory
# to stay within limits (maybe because the iterator doesn't free the previous object
# before getting the next). We also need extra memory to hold two reduced chunks, since
Expand All @@ -1133,7 +1133,7 @@ def block_function(out_key):

return general_blockwise(
_partial_reduce,
block_function,
key_function,
x,
shape=shape,
dtype=dtype,
Expand Down
87 changes: 43 additions & 44 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class BlockwiseSpec:
Attributes
----------
block_function : Callable
A function that maps an output chunk index to one or more input chunk indexes.
key_function : Callable
A function that maps an output chunk key to one or more input chunk keys.
function : Callable
A function that maps input chunks to an output chunk.
function_nargs: int
Expand All @@ -58,28 +58,29 @@ class BlockwiseSpec:
Write proxy with an ``array`` attribute that supports ``__setitem__``.
"""

block_function: Callable[..., Any]
key_function: Callable[..., Any]
function: Callable[..., Any]
function_nargs: int
num_input_blocks: Tuple[int, ...]
reads_map: Dict[str, CubedArrayProxy]
write: CubedArrayProxy


def apply_blockwise(out_key: List[int], *, config: BlockwiseSpec) -> None:
def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None:
"""Stage function for blockwise."""
# lithops needs params to be lists not tuples, so convert back
out_key_tuple = tuple(out_key)
out_coords_tuple = tuple(out_coords)
out_chunk_key = key_to_slices(
out_key_tuple, config.write.array, config.write.chunks
out_coords_tuple, config.write.array, config.write.chunks
)

# get array chunks for input keys, preserving any nested list structure
args = []
get_chunk_config = partial(get_chunk, config=config)
name_chunk_inds = config.block_function(("out",) + out_key_tuple)
for name_chunk_ind in name_chunk_inds:
arg = map_nested(get_chunk_config, name_chunk_ind)
out_key = ("out",) + out_coords_tuple # array name is ignored by key_function
in_keys = config.key_function(out_key)
for in_key in in_keys:
arg = map_nested(get_chunk_config, in_key)
args.append(arg)

result = config.function(*args)
Expand All @@ -100,13 +101,13 @@ def key_to_slices(
return get_item(chunks, key)


def get_chunk(name_chunk_ind, config):
def get_chunk(in_key, config):
"""Read a chunk from the named array"""
name = name_chunk_ind[0]
chunk_ind = name_chunk_ind[1:]
name = in_key[0]
in_coords = in_key[1:]
arr = config.reads_map[name].open()
chunk_key = key_to_slices(chunk_ind, arr)
arg = arr[chunk_key]
selection = key_to_slices(in_coords, arr)
arg = arr[selection]
arg = numpy_array_to_backend_array(arg)
return arg

Expand Down Expand Up @@ -188,7 +189,7 @@ def blockwise(
for name, ind in zip(array_names, inds):
argindsstr.extend((name, ind))

block_function = make_blockwise_function_flattened(
key_function = make_blockwise_key_function_flattened(
func,
out_name or "out",
out_ind,
Expand All @@ -199,7 +200,7 @@ def blockwise(

return general_blockwise(
func,
block_function,
key_function,
*arrays,
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
Expand All @@ -219,7 +220,7 @@ def blockwise(

def general_blockwise(
func: Callable[..., Any],
block_function: Callable[..., Any],
key_function: Callable[..., Any],
*arrays: Any,
allowed_mem: int,
reserved_mem: int,
Expand All @@ -242,8 +243,8 @@ def general_blockwise(
----------
func : callable
Function to apply to individual tuples of blocks
block_function : callable
A function that maps an output chunk index to one or more input chunk indexes.
key_function : callable
A function that maps an output chunk key to one or more input chunk keys.
*arrays : sequence of Array
The input arrays.
allowed_mem : int
Expand Down Expand Up @@ -291,7 +292,7 @@ def general_blockwise(
}
write_proxy = CubedArrayProxy(target_array, chunksize)
spec = BlockwiseSpec(
block_function,
key_function,
func_with_kwargs,
len(arrays),
num_input_blocks,
Expand Down Expand Up @@ -460,10 +461,8 @@ def fuse(

mappable = pipeline2.mappable

def fused_blockwise_func(out_key):
return pipeline1.config.block_function(
*pipeline2.config.block_function(out_key)
)
def fused_key_func(out_key):
return pipeline1.config.key_function(*pipeline2.config.key_function(out_key))

def fused_func(*args):
return pipeline2.config.function(pipeline1.config.function(*args))
Expand All @@ -476,7 +475,7 @@ def fused_func(*args):
for n in pipeline1.config.num_input_blocks
)
spec = BlockwiseSpec(
fused_blockwise_func,
fused_key_func,
fused_func,
function_nargs,
num_input_blocks,
Expand Down Expand Up @@ -530,36 +529,36 @@ def fuse_multiple(

mappable = pipeline.mappable

def apply_pipeline_block_func(pipeline, n_input_blocks, arg):
def apply_pipeline_key_func(pipeline, n_input_blocks, arg):
if pipeline is None:
return (arg,)
if n_input_blocks == 1:
assert isinstance(arg, tuple)
return pipeline.config.block_function(arg)
return pipeline.config.key_function(arg)
else:
# more than one input block is being read from arg
assert isinstance(arg, (list, Iterator))
if isinstance(arg, list):
return tuple(
list(item)
for item in zip(*(pipeline.config.block_function(a) for a in arg))
for item in zip(*(pipeline.config.key_function(a) for a in arg))
)
else:
# Return iterators to avoid materializing all array blocks at
# once.
return tuple(
iter(list(item))
for item in zip(*(pipeline.config.block_function(a) for a in arg))
for item in zip(*(pipeline.config.key_function(a) for a in arg))
)

def fused_blockwise_func(out_key):
def fused_key_func(out_key):
# this will change when multiple outputs are supported
args = pipeline.config.block_function(out_key)
args = pipeline.config.key_function(out_key)
# split all args to the fused function into groups, one for each predecessor function
func_args = tuple(
item
for i, (p, a) in enumerate(zip(predecessor_pipelines, args))
for item in apply_pipeline_block_func(
for item in apply_pipeline_key_func(
p, pipeline.config.num_input_blocks[i], a
)
)
Expand Down Expand Up @@ -602,7 +601,7 @@ def fused_func(*args):
read_proxies.update(p.config.reads_map)
write_proxy = pipeline.config.write
spec = BlockwiseSpec(
fused_blockwise_func,
fused_key_func,
fused_func,
fused_function_nargs,
fused_num_input_blocks,
Expand Down Expand Up @@ -643,10 +642,10 @@ def fused_func(*args):
)


# blockwise functions
# blockwise key functions


def make_blockwise_function(
def make_blockwise_key_function(
func: Callable[..., Any],
output: str,
out_indices: Sequence[Union[str, int]],
Expand Down Expand Up @@ -675,7 +674,7 @@ def make_blockwise_function(
False,
)

def blockwise_fn(out_key):
def key_function(out_key):
out_coords = out_key[1:]

# from Dask make_blockwise_graph
Expand All @@ -701,27 +700,27 @@ def blockwise_fn(out_key):

return val

return blockwise_fn
return key_function


def make_blockwise_function_flattened(
def make_blockwise_key_function_flattened(
func: Callable[..., Any],
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
new_axes: Optional[Dict[int, int]] = None,
) -> Callable[[List[int]], Any]:
# TODO: make this a part of make_blockwise_function?
blockwise_fn = make_blockwise_function(
# TODO: make this a part of make_blockwise_key_function?
key_function = make_blockwise_key_function(
func, output, out_indices, *arrind_pairs, numblocks=numblocks, new_axes=new_axes
)

def blockwise_fn_flattened(out_key):
name_chunk_inds = blockwise_fn(out_key)[1:] # drop function in position 0
in_keys = key_function(out_key)[1:] # drop function in position 0
# flatten (nested) lists indicating contraction
if isinstance(name_chunk_inds[0], list):
name_chunk_inds = list(flatten(name_chunk_inds))
return name_chunk_inds
if isinstance(in_keys[0], list):
in_keys = list(flatten(in_keys))
return in_keys

return blockwise_fn_flattened
Loading

0 comments on commit 3fd2195

Please sign in to comment.