diff --git a/README.rst b/README.rst index 05873b6fb..c9023fd03 100644 --- a/README.rst +++ b/README.rst @@ -33,7 +33,3 @@ Pytato is written to pose no particular restrictions on the version of numpy used for execution. To use mypy-based type checking on Pytato itself or packages using Pytato, numpy 1.20 or newer is required, due to the typing-based changes to numpy in that release. - -Furthermore, pytato now uses type promotion rules based on those in -`numpy `__ that should result in the same -data types as the currently installed version of numpy. diff --git a/examples/advection.py b/examples/advection.py index 339ff80a8..fd308ae50 100755 --- a/examples/advection.py +++ b/examples/advection.py @@ -156,6 +156,7 @@ def test_advection_convergence(order, flux_type): op = AdvectionOperator(discr, c=1, flux_type=flux_type, dg_ops=dg_ops) result = op.apply(u) + result = pt.transform.Deduplicator()(result) prog = pt.generate_loopy(result, cl_device=queue.device) diff --git a/pytato/__init__.py b/pytato/__init__.py index 56254b928..c6066ac99 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -158,7 +158,13 @@ def set_debug_enabled(flag: bool) -> None: from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.target.loopy.codegen import generate_loopy from pytato.target.python.jax import generate_jax -from pytato.transform.calls import inline_calls, tag_all_calls_to_be_inlined +from pytato.transform import precompute_subexpressions +from pytato.transform.calls import ( + concatenate_calls, + inline_calls, + tag_all_calls_to_be_inlined, + zero_unused_call_bindings, +) from pytato.transform.lower_to_index_lambda import to_index_lambda from pytato.transform.metadata import unify_axes_tags from pytato.transform.remove_broadcasts_einsum import rewrite_einsums_with_no_broadcasts @@ -215,6 +221,7 @@ def set_debug_enabled(flag: bool) -> None: "arctan2", "broadcast_to", "concatenate", + "concatenate_calls", "conj", "cos", "cosh", @@ -260,6 +267,7 @@ def set_debug_enabled(flag: bool) -> None: "number_distributed_tags", "ones", "pad", + "precompute_subexpressions", "prod", "real", "reshape", @@ -287,5 +295,6 @@ def set_debug_enabled(flag: bool) -> None: "vdot", "verify_distributed_partition", "where", + "zero_unused_call_bindings", "zeros", ) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e1487b710..75eeb2723 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -28,15 +28,19 @@ from typing import TYPE_CHECKING, Any, Never -from orderedsets import FrozenOrderedSet +from orderedsets import FrozenOrderedSet, OrderedSet from typing_extensions import Self from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method +from loopy.tools import LoopyKeyBuilder from pytato.array import ( Array, + AxisPermutation, Concatenate, + DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -44,20 +48,34 @@ IndexRemappingBase, InputArgumentBase, NamedArray, + Placeholder, + Reshape, + Roll, ShapeType, + SizeParam, Stack, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, Mapper +from pytato.transform import ( + ArrayOrNames, + CachedWalkMapper, + CombineMapper, + IndexOrShapeExpr, + Mapper, +) if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Callable, Iterable, Mapping import pytools from pytato.distributed.nodes import DistributedRecv, 
DistributedSendRefHolder - from pytato.loopy import LoopyCall + from pytato.loopy import LoopyCall, LoopyCallResult + + +# FIXME: Think about whether this makes sense +NodeT = Array | FunctionDefinition | Call __doc__ = """ .. currentmodule:: pytato.analysis @@ -74,6 +92,12 @@ .. autofunction:: get_num_call_sites +.. autofunction:: collect_nodes_of_type + +.. autofunction:: collect_materialized_nodes + +.. autofunction:: trace_dependencies + .. autoclass:: DirectPredecessorsGetter .. autoclass:: TagCountMapper @@ -83,6 +107,7 @@ # {{{ NUserCollector +# FIXME: Use ordered sets class NUserCollector(Mapper[None, None, []]): """ A :class:`pytato.transform.CachedWalkMapper` that records the number of @@ -226,6 +251,7 @@ def get_nusers(outputs: Array | DictOfNamedArrays) -> Mapping[Array, int]: # {{{ is_einsum_similar_to_subscript +# FIXME: Use ordered sets def _get_indices_from_input_subscript(subscript: str, is_output: bool, ) -> tuple[str, ...]: @@ -323,7 +349,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): +class DirectPredecessorsGetter( + Mapper[ + FrozenOrderedSet[ArrayOrNames | FunctionDefinition], + FrozenOrderedSet[ArrayOrNames], + []]): """ Mapper to get the `direct predecessors @@ -334,9 +364,17 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): We only consider the predecessors of a nodes in a data-flow sense. """ + def __init__(self, *, include_functions: bool = False) -> None: + super().__init__() + self.include_functions = include_functions + def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr._data.values()) + def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]: return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) @@ -397,8 +435,17 @@ def map_distributed_send_ref_holder(self, ) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet([expr.passthrough_data]) - def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]: - return FrozenOrderedSet(expr.bindings.values()) + def map_call( + self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]: + result: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \ + FrozenOrderedSet(expr.bindings.values()) + if self.include_functions: + result = result | FrozenOrderedSet([expr.function]) + return result + + def map_function_definition( + self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr.returns.values()) def map_named_call_result( self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]: @@ -416,20 +463,28 @@ class NodeCountMapper(CachedWalkMapper[[]]): Counts the number of nodes of a given type in a DAG. .. autoattribute:: expr_type_counts + + Dictionary mapping node types to number of nodes of that type. + .. autoattribute:: count_duplicates - Dictionary mapping node types to number of nodes of that type. + If `True`, counts each array instance as a separate node, even if some are + equal. 
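    For illustration (a hypothetical sketch, not part of the original patch; it
    assumes the usual :func:`pytato.make_placeholder` API and the
    :func:`get_num_nodes` convenience wrapper defined below)::

        import numpy as np
        import pytato as pt
        from pytato.analysis import get_num_nodes

        # two equal but distinct placeholder instances
        x1 = pt.make_placeholder("x", shape=(4,), dtype=np.float64)
        x2 = pt.make_placeholder("x", shape=(4,), dtype=np.float64)
        dag = x1 + x2

        get_num_nodes(dag, count_duplicates=False)  # the placeholder is counted once
        get_num_nodes(dag, count_duplicates=True)   # both instances are counted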
+ """ def __init__( self, count_duplicates: bool = False, - _visited_functions: set[Any] | None = None, + traverse_functions: bool = True, + _visited_functions: OrderedSet[Any] | None = None, ) -> None: super().__init__(_visited_functions=_visited_functions) + self.traverse_functions = traverse_functions + from collections import defaultdict - self.expr_type_counts: dict[type[Any], int] = defaultdict(int) + self.expr_type_counts: dict[type[NodeT], int] = defaultdict(int) self.count_duplicates = count_duplicates def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: @@ -444,17 +499,34 @@ def get_function_definition_cache_key( def clone_for_callee(self, function: FunctionDefinition) -> Self: return type(self)( count_duplicates=self.count_duplicates, + traverse_functions=self.traverse_functions, _visited_functions=self._visited_functions) + def visit(self, expr: Any) -> bool: + return not isinstance(expr, FunctionDefinition) or self.traverse_functions + def post_visit(self, expr: Any) -> None: - if not isinstance(expr, DictOfNamedArrays): + if isinstance(expr, NodeT): self.expr_type_counts[type(expr)] += 1 + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee(expr) + for ret in expr.returns.values(): + new_mapper(ret) + + for node_type, count in new_mapper.expr_type_counts.items(): + self.expr_type_counts[node_type] += count + + self.post_visit(expr) + def get_node_type_counts( outputs: Array | DictOfNamedArrays, count_duplicates: bool = False - ) -> dict[type[Any], int]: + ) -> dict[type[NodeT], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -473,7 +545,8 @@ def get_node_type_counts( def get_num_nodes( outputs: Array | DictOfNamedArrays, - count_duplicates: bool | None = None + count_duplicates: bool | None = None, + traverse_functions: bool = True ) -> int: """ Returns the number of nodes in DAG *outputs*. @@ -491,11 +564,38 @@ def get_num_nodes( from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper(count_duplicates) + ncm = NodeCountMapper( + count_duplicates=count_duplicates, + traverse_functions=traverse_functions) ncm(outputs) return sum(ncm.expr_type_counts.values()) + +def get_num_node_instances( + outputs: Array | DictOfNamedArrays, + node_type: type[NodeT], + strict: bool = True, + count_duplicates: bool = False) -> int: + """ + Returns the number of nodes in DAG *outputs* that have type *node_type* (if + *strict* is `True`) or are instances of *node_type* (if *strict* is `False`). + """ + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper(count_duplicates) + ncm(outputs) + + if strict: + return ncm.expr_type_counts[node_type] + else: + return sum( + count + for other_node_type, count in ncm.expr_type_counts.items() + if isinstance(other_node_type, node_type)) + # }}} @@ -511,11 +611,16 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]): .. 
autoattribute:: expr_multiplicity_counts """ - def __init__(self, _visited_functions: set[Any] | None = None) -> None: + def __init__( + self, + traverse_functions: bool = True, + _visited_functions: OrderedSet[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) + self.traverse_functions = traverse_functions + from collections import defaultdict - self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) + self.expr_multiplicity_counts: dict[NodeT, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: # Returns each node, including nodes that are duplicates @@ -525,13 +630,29 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: # Returns each node, including nodes that are duplicates return id(expr) + def visit(self, expr: Any) -> bool: + return not isinstance(expr, FunctionDefinition) or self.traverse_functions + def post_visit(self, expr: Any) -> None: - if not isinstance(expr, DictOfNamedArrays): + if isinstance(expr, NodeT): self.expr_multiplicity_counts[expr] += 1 + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee(expr) + for ret in expr.returns.values(): + new_mapper(ret) + + for subexpr, count in new_mapper.expr_multiplicity_counts.items(): + self.expr_multiplicity_counts[subexpr] += count + + self.post_visit(expr) + def get_node_multiplicities( - outputs: Array | DictOfNamedArrays) -> dict[Array, int]: + outputs: Array | DictOfNamedArrays) -> dict[NodeT, int]: """ Returns the multiplicity per `expr`. """ @@ -558,7 +679,7 @@ class CallSiteCountMapper(CachedWalkMapper[[]]): The number of nodes. """ - def __init__(self, _visited_functions: set[Any] | None = None) -> None: + def __init__(self, _visited_functions: OrderedSet[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) self.count = 0 @@ -568,21 +689,21 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: return id(expr) + def post_visit(self, expr: Any) -> None: + if isinstance(expr, Call): + self.count += 1 + def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return new_mapper = self.clone_for_callee(expr) - for subexpr in expr.returns.values(): - new_mapper(subexpr) + for ret in expr.returns.values(): + new_mapper(ret) self.count += new_mapper.count self.post_visit(expr) - def post_visit(self, expr: Any) -> None: - if isinstance(expr, Call): - self.count += 1 - def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -600,6 +721,7 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int: # {{{ TagCountMapper +# FIXME: Use ordered sets class TagCountMapper(CombineMapper[int, Never]): """ Returns the number of nodes in a DAG that are tagged with all the tag types in @@ -622,9 +744,9 @@ def combine(self, *args: int) -> int: return sum(args) def rec(self, expr: ArrayOrNames) -> int: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -639,7 +761,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(expr, 0, key=key) + 
self._cache_add(inputs, 0) return result @@ -658,16 +780,13 @@ def get_num_tags_of_type( # {{{ PytatoKeyBuilder -class PytatoKeyBuilder(LoopyKeyBuilder): + +class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc] """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass for objects within :mod:`pytato`. """ - # The types below aren't immutable in general, but in the context of - # pytato, they are used as such. def update_for_ndarray(self, key_hash: Any, key: Any) -> None: - import numpy as np - assert isinstance(key, np.ndarray) self.rec(key_hash, key.data.tobytes()) def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: @@ -678,10 +797,266 @@ def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: self.rec(key_hash, key.get()) def update_for_Array(self, key_hash: Any, key: Any) -> None: - from pyopencl.array import Array - assert isinstance(key, Array) + # CL Array self.rec(key_hash, key.get()) +# }}} + + +# {{{ NodeCollector + +# FIXME: Decide if this should be a CombineMapper instead? +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class NodeCollector(CachedWalkMapper[[]]): + """ + Collects all nodes matching specified criteria in a DAG. + + .. attribute:: nodes + + The collected nodes. + """ + + def __init__( + self, + collect_func: Callable[[NodeT], bool], + traverse_functions: bool = True) -> None: + super().__init__() + self.collect_func = collect_func + self.traverse_functions = traverse_functions + self.nodes: OrderedSet[NodeT] = OrderedSet() + + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: + return expr + + def get_function_definition_cache_key( + self, expr: FunctionDefinition) -> FunctionDefinition: + return expr + + def clone_for_callee( + self: NodeCollector, function: FunctionDefinition) -> NodeCollector: + return type(self)(self.collect_func) + + def visit(self, expr: Any) -> bool: + return not isinstance(expr, FunctionDefinition) or self.traverse_functions + + def post_visit(self, expr: Any) -> None: + if isinstance(expr, NodeT) and self.collect_func(expr): + self.nodes.add(expr) + + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee(expr) + for ret in expr.returns.values(): + new_mapper(ret) + + self.nodes |= new_mapper.nodes + + self.post_visit(expr) + + +def collect_nodes_of_type( + outputs: Array | DictOfNamedArrays, + node_type: type[NodeT]) -> FrozenOrderedSet[NodeT]: + """Returns the nodes that are instances of *node_type* in DAG *outputs*.""" + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + def collect_func(expr: NodeT) -> bool: + return isinstance(expr, node_type) + + nc = NodeCollector(collect_func) + nc(outputs) + + return FrozenOrderedSet(nc.nodes) + + +def collect_materialized_nodes( + outputs: Array | DictOfNamedArrays) -> FrozenOrderedSet[NodeT]: + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + def collect_func(expr: NodeT) -> bool: + from pytato.tags import ImplStored + return bool(expr.tags_of_type(ImplStored)) + + nc = NodeCollector(collect_func) + nc(outputs) + + return FrozenOrderedSet(nc.nodes) # }}} + +# {{{ DependencyTracer + +# FIXME: Use ordered sets +class DependencyTracer(CombineMapper[frozenset[tuple[Array, ...]], Never]): + """ + Maps a DAG and a node to a :class:`frozenset` of `tuple`\\ s of + :class:`pytato.array.Array`\\ s representing dependency traces from + the node to one of 
the DAG outputs. + + .. note:: + + Does not recurse into function definitions. + """ + def __init__(self, dependee: Array) -> None: + super().__init__() + self.dependee = dependee + + def rec_idx_or_size_tuple( + self, situp: tuple[IndexOrShapeExpr, ...] + ) -> tuple[frozenset[tuple[Array, ...]], ...]: + return tuple(self.rec(s) for s in situp if isinstance(s, Array)) + + def combine( + self, *args: frozenset[tuple[Array, ...]]) -> frozenset[tuple[Array, ...]]: + from functools import reduce + # FIXME: This doesn't match the docs (original version produced way too + # many results) + combined: frozenset[tuple[Array, ...]] = reduce( + lambda a, b: a | b, args, frozenset()) + if combined: + return frozenset({next(iter(combined))}) + else: + return frozenset() + + def map_index_lambda(self, expr: IndexLambda) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_index_lambda(expr))) + + def map_placeholder(self, expr: Placeholder) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_placeholder(expr))) + + def map_data_wrapper(self, expr: DataWrapper) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_data_wrapper(expr))) + + def map_size_param(self, expr: SizeParam) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_size_param(expr))) + + def map_stack(self, expr: Stack) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_stack(expr))) + + def map_roll(self, expr: Roll) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_roll(expr))) + + def map_axis_permutation( + self, expr: AxisPermutation) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_axis_permutation(expr))) + + def _map_index_base(self, expr: IndexBase) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super()._map_index_base(expr))) + + def map_reshape(self, expr: Reshape) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_reshape(expr))) + + def map_concatenate(self, expr: Concatenate) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_concatenate(expr))) + + def map_einsum(self, expr: Einsum) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_einsum(expr))) + + def map_named_array(self, expr: NamedArray) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: 
+ return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_named_array(expr))) + + def map_loopy_call(self, expr: LoopyCall) -> frozenset[tuple[Array, ...]]: + raise AssertionError("Control shouldn't reach this point.") + + def map_loopy_call_result( + self, expr: LoopyCallResult) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_loopy_call_result(expr))) + + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_distributed_send_ref_holder(expr))) + + def map_distributed_recv( + self, expr: DistributedRecv) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_distributed_recv(expr))) + + def map_call(self, expr: Call) -> frozenset[tuple[Array, ...]]: + return self.combine(*(self.rec(bnd) + for bnd in expr.bindings.values())) + + def map_named_call_result( + self, expr: NamedCallResult) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_named_call_result(expr))) + + +def trace_dependencies( + outputs: Array | DictOfNamedArrays, dependee: Array + ) -> frozenset[tuple[Array, ...]]: + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + dt = DependencyTracer(dependee) + return dt(outputs) + +# }}} + + # vim: fdm=marker diff --git a/pytato/array.py b/pytato/array.py index 8f92e5118..267bdeaa5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1,4 +1,5 @@ from __future__ import annotations +from traceback import FrameSummary, StackSummary __copyright__ = """ @@ -299,7 +300,7 @@ def array_dataclass(hash: bool = True) -> Callable[[type[T]], type[T]]: def map_cls(cls: type[T]) -> type[T]: # Frozen dataclasses (empirically) have a ~20% speed penalty, # and their frozen-ness is arguably a debug feature. - dc_cls = dataclasses.dataclass(init=True, frozen=__debug__, + dc_cls = dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False)(cls) _augment_array_dataclass(dc_cls, generate_hash=hash) @@ -359,6 +360,9 @@ def {cls.__name__}_hash(self): cls.__hash__ = {cls.__name__}_hash + # cls.__getstate__ = dataclasses._dataclass_getstate + # cls.__setstate__ = dataclasses._dataclass_setstate + # By default (when slots=False), dataclasses do not have special # handling for pickling, thus using pickle's default behavior that # looks at obj.__dict__. 
This would also pickle the cached hash, @@ -411,20 +415,31 @@ def _dataclass_setstate(self, state): # {{{ array interface + ConvertibleToIndexExpr = Union[int, slice, "Array", EllipsisType, None] +# IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None, EllipsisType] +# IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None] +# DtypeOrScalar = Union[_dtype_any, ScalarT] +# ArrayOrScalar = Union["Array", ScalarT] +DtypeOrScalar = Union[_dtype_any, Scalar] +ArrayOrScalar = Union["Array", Scalar] IndexExpr = Union[Integer, "NormalizedSlice", "Array", None] PyScalarType = type[bool] | type[int] | type[float] | type[complex] DtypeOrPyScalarType = _dtype_any | PyScalarType -def _np_result_dtype( - *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +# https://github.com/numpy/numpy/issues/19302 +def _np_result_type( + # actual dtype: + #*arrays_and_dtypes: Union[np.typing.ArrayLike, np.typing.DTypeLike], + # our dtype: + *arrays_and_dtypes: DtypeOrScalar, ) -> np.dtype[Any]: return np.result_type(*arrays_and_dtypes) -def _truediv_result_type(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]: - dtype = _np_result_dtype(*dtypes) +def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype[Any]: + dtype = _np_result_type(arg1, arg2) # See: test_true_divide in numpy/core/tests/test_ufunc.py # pylint: disable=no-member if dtype.kind in "iu": @@ -467,8 +482,11 @@ class Axis(Taggable): tags: frozenset[Tag] def _with_new_tags(self, tags: frozenset[Tag]) -> Axis: - from dataclasses import replace - return replace(self, tags=tags) + if tags != self.tags: + from dataclasses import replace + return replace(self, tags=tags) + else: + return self @dataclasses.dataclass(frozen=True) @@ -480,8 +498,11 @@ class ReductionDescriptor(Taggable): tags: frozenset[Tag] def _with_new_tags(self, tags: frozenset[Tag]) -> ReductionDescriptor: - from dataclasses import replace - return replace(self, tags=tags) + if tags != self.tags: + from dataclasses import replace + return replace(self, tags=tags) + else: + return self @array_dataclass() @@ -686,11 +707,12 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array: def _binary_op( self, - op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], + op: Callable[[ScalarExpression, ScalarExpression], + ScalarExpression], other: ArrayOrScalar, get_result_type: Callable[ [ArrayOrScalar, ArrayOrScalar], - np.dtype[Any]] = _np_result_dtype, + np.dtype[Any]] = _np_result_type, reverse: bool = False, cast_to_result_dtype: bool = True, is_pow: bool = False, @@ -746,21 +768,33 @@ def _unary_op(self, op: Any) -> Array: non_equality_tags=_get_created_at_tag(), var_to_reduction_descr=immutabledict()) - __mul__ = partialmethod(_binary_op, operator.mul) - __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True) + # NOTE: Initializing the expression to "prim.Product(expr1, expr2)" is + # essential as opposed to performing "expr1 * expr2". This is to account + # for pymbolic's implementation of the "*" operator which might not + # instantiate the node corresponding to the operation when one of + # the operands is the neutral element of the operation. + # + # For the same reason 'prim.(Sum|FloorDiv|Quotient)' is preferred over the + # python operators on the operands. 
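    # A small illustration (not part of the patch; it assumes pymbolic's current
    # neutral-element folding in the overloaded operators):
    #
    #   import pymbolic.primitives as prim
    #
    #   x = prim.Variable("x")
    #   x * 1                   # folded to Variable("x"); no Product node created
    #   prim.Product((x, 1))    # explicitly constructs the Product node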
+ + __mul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r))) + __rmul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)), + reverse=True) - __add__ = partialmethod(_binary_op, operator.add) - __radd__ = partialmethod(_binary_op, operator.add, reverse=True) + __add__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r))) + __radd__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)), + reverse=True) - __sub__ = partialmethod(_binary_op, operator.sub) - __rsub__ = partialmethod(_binary_op, operator.sub, reverse=True) + __sub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r))) + __rsub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)), + reverse=True) - __floordiv__ = partialmethod(_binary_op, operator.floordiv) - __rfloordiv__ = partialmethod(_binary_op, operator.floordiv, reverse=True) + __floordiv__ = partialmethod(_binary_op, prim.FloorDiv) + __rfloordiv__ = partialmethod(_binary_op, prim.FloorDiv, reverse=True) - __truediv__ = partialmethod(_binary_op, operator.truediv, + __truediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type) - __rtruediv__ = partialmethod(_binary_op, operator.truediv, + __rtruediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type, reverse=True) __mod__ = partialmethod(_binary_op, operator.mod) @@ -833,10 +867,14 @@ def with_tagged_axis(self, iaxis: int, """ Returns a copy of *self* with *iaxis*-th axis tagged with *tags*. """ - new_axes = (self.axes[:iaxis] - + (self.axes[iaxis].tagged(tags),) - + self.axes[iaxis+1:]) - return self.copy(axes=new_axes) + new_axis = self.axes[iaxis].tagged(tags) + if new_axis is not self.axes[iaxis]: + new_axes = (self.axes[:iaxis] + + (self.axes[iaxis].tagged(tags),) + + self.axes[iaxis+1:]) + return self.copy(axes=new_axes) + else: + return self @memoize_method def __repr__(self) -> str: @@ -865,7 +903,10 @@ class _SuppliedAxesAndTagsMixin(Taggable): default=frozenset()) def _with_new_tags(self: Self, tags: frozenset[Tag]) -> Self: - return dataclasses.replace(self, tags=tags) + if tags != self.tags: + return dataclasses.replace(self, tags=tags) + else: + return self @dataclasses.dataclass(frozen=True, eq=False, repr=False) @@ -1114,20 +1155,22 @@ def with_tagged_reduction(self, f" '{self.var_to_reduction_descr.keys()}'," f" got '{reduction_variable}'.") - assert isinstance(self.var_to_reduction_descr, immutabledict) - new_var_to_redn_descr = dict(self.var_to_reduction_descr) - new_var_to_redn_descr[reduction_variable] = \ - self.var_to_reduction_descr[reduction_variable].tagged(tags) - - return type(self)(expr=self.expr, - shape=self.shape, - dtype=self.dtype, - bindings=self.bindings, - axes=self.axes, - var_to_reduction_descr=immutabledict - (new_var_to_redn_descr), - tags=self.tags, - non_equality_tags=self.non_equality_tags) + new_redn_descr = self.var_to_reduction_descr[reduction_variable].tagged(tags) + if new_redn_descr is not self.var_to_reduction_descr[reduction_variable]: + assert isinstance(self.var_to_reduction_descr, immutabledict) + new_var_to_redn_descr = dict(self.var_to_reduction_descr) + new_var_to_redn_descr[reduction_variable] = new_redn_descr + return type(self)(expr=self.expr, + shape=self.shape, + dtype=self.dtype, + bindings=self.bindings, + axes=self.axes, + var_to_reduction_descr=immutabledict + (new_var_to_redn_descr), + tags=self.tags, + non_equality_tags=self.non_equality_tags) + else: + return self # }}} @@ -1162,6 +1205,34 @@ class EinsumReductionAxis(EinsumAxisDescriptor): dim: 
int +def _get_einsum_access_descr_to_axis_len( + access_descriptors: tuple[tuple[EinsumAxisDescriptor, ...], ...], + args: tuple[Array, ...], + ) -> Mapping[EinsumAxisDescriptor, ShapeComponent]: + from pytato.utils import are_shape_components_equal + descr_to_axis_len: dict[EinsumAxisDescriptor, ShapeComponent] = {} + + for access_descrs, arg in zip(access_descriptors, + args, strict=True): + assert arg.ndim == len(access_descrs) + for arg_axis_len, descr in zip(arg.shape, access_descrs, strict=True): + if descr in descr_to_axis_len: + seen_axis_len = descr_to_axis_len[descr] + + if not are_shape_components_equal(seen_axis_len, + arg_axis_len): + if are_shape_components_equal(arg_axis_len, 1): + # this axis would be broadcasted + pass + else: + assert are_shape_components_equal(seen_axis_len, 1) + descr_to_axis_len[descr] = arg_axis_len + else: + descr_to_axis_len[descr] = arg_axis_len + + return immutabledict(descr_to_axis_len) + + @array_dataclass() class Einsum(_SuppliedAxesAndTagsMixin, Array): """ @@ -1209,28 +1280,8 @@ def __post_init__(self) -> None: @memoize_method def _access_descr_to_axis_len(self ) -> Mapping[EinsumAxisDescriptor, ShapeComponent]: - from pytato.utils import are_shape_components_equal - descr_to_axis_len: dict[EinsumAxisDescriptor, ShapeComponent] = {} - - for access_descrs, arg in zip(self.access_descriptors, - self.args, strict=True): - assert arg.ndim == len(access_descrs) - for arg_axis_len, descr in zip(arg.shape, access_descrs, strict=True): - if descr in descr_to_axis_len: - seen_axis_len = descr_to_axis_len[descr] - - if not are_shape_components_equal(seen_axis_len, - arg_axis_len): - if are_shape_components_equal(arg_axis_len, 1): - # this axis would be broadcasted - pass - else: - assert are_shape_components_equal(seen_axis_len, 1) - descr_to_axis_len[descr] = arg_axis_len - else: - descr_to_axis_len[descr] = arg_axis_len - - return immutabledict(descr_to_axis_len) + return _get_einsum_access_descr_to_axis_len( + self.access_descriptors, self.args) @cached_property def shape(self) -> ShapeType: @@ -1278,19 +1329,21 @@ def with_tagged_reduction(self, # }}} - assert isinstance(self.redn_axis_to_redn_descr, immutabledict) - new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr) - new_redn_axis_to_redn_descr[redn_axis] = \ - self.redn_axis_to_redn_descr[redn_axis].tagged(tags) - - return type(self)(access_descriptors=self.access_descriptors, - args=self.args, - axes=self.axes, - redn_axis_to_redn_descr=immutabledict - (new_redn_axis_to_redn_descr), - tags=self.tags, - non_equality_tags=self.non_equality_tags, - ) + new_redn_descr = self.redn_axis_to_redn_descr[redn_axis].tagged(tags) + if new_redn_descr is not self.redn_axis_to_redn_descr[redn_axis]: + assert isinstance(self.redn_axis_to_redn_descr, immutabledict) + new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr) + new_redn_axis_to_redn_descr[redn_axis] = new_redn_descr + return type(self)(access_descriptors=self.access_descriptors, + args=self.args, + axes=self.axes, + redn_axis_to_redn_descr=immutabledict + (new_redn_axis_to_redn_descr), + tags=self.tags, + non_equality_tags=self.non_equality_tags, + ) + else: + return self EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P[a-zA-Z])|(?P\.\.\.))\s*") @@ -1521,7 +1574,7 @@ class Stack(_SuppliedAxesAndTagsMixin, Array): @property def dtype(self) -> np.dtype[Any]: - return _np_result_dtype(*(arr.dtype for arr in self.arrays)) + return _np_result_type(*(arr.dtype for arr in self.arrays)) @property def shape(self) -> ShapeType: @@ -1552,7 
+1605,7 @@ class Concatenate(_SuppliedAxesAndTagsMixin, Array): @property def dtype(self) -> np.dtype[Any]: - return _np_result_dtype(*(arr.dtype for arr in self.arrays)) + return _np_result_type(*(arr.dtype for arr in self.arrays)) @property def shape(self) -> ShapeType: @@ -1993,6 +2046,25 @@ def _get_created_at_tag(stacklevel: int = 1) -> frozenset[Tag]: return frozenset({CreatedAt(_PytatoStackSummary(frames))}) +def _inherit_created_at_tag_from(ary: Array, src_ary: Array) -> Array: + from pytato.tags import CreatedAt + try: + tb_tag = next( + tag for tag in src_ary.non_equality_tags + if isinstance(tag, CreatedAt)) + except StopIteration: + tb_tag = None + + if tb_tag is not None: + return attrs.evolve( + ary, + non_equality_tags=frozenset({ + tb_tag if isinstance(tag, CreatedAt) else tag + for tag in ary.non_equality_tags})) + else: + return ary + + def _get_default_tags() -> frozenset[Tag]: return frozenset() @@ -2164,9 +2236,9 @@ def reshape(array: Array, newshape: int | Sequence[int], *and* the output array are linearized according to this order and 'matched up'. - Groups are found by multiplying axis lengths on the input and output side, - a matching input/output group is found once adding an input or axis to the - group makes the two products match. + Groups are found by multiplying axis lengths on the input and output + side, a matching input/output group is found once adding an input or + axis to the group makes the two products match. The semantics are identical to :func:`numpy.reshape`. diff --git a/pytato/codegen.py b/pytato/codegen.py index 86a328929..f01296e1b 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -138,57 +138,75 @@ def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + super().__init__( + # ToIndexLambdaMixin operates on certain array types for which `shape` + # is a derived property (e.g. BasicIndex). For these types, `shape` + # is an expression that may contain duplicate nodes. Mappers do not + # traverse properties, so these expressions are not subject to any prior + # deduplication. Once transformed into an IndexLambda, however, `shape` + # becomes a field and is subject to traversal and duplication checks. + # Without `err_on_collision=False`, these duplicates would lead to + # collision errors. 
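            # An illustrative sketch (hypothetical, not part of the patch) of such a
            # derived shape:
            #
            #   import pytato as pt
            #
            #   n = pt.make_size_param("n")
            #   a = pt.make_placeholder("a", shape=(n, 4), dtype=float)
            #   sliced = a[:, :2]   # BasicIndex; `shape` == (n, 2) is a derived property
            #
            #   # Once `sliced` is lowered to an IndexLambda, the shape tuple is stored
            #   # as a field and its sub-expressions are traversed by mappers.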
+ err_on_collision=False, + _cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target self.kernels_seen: dict[str, lp.LoopKernel] = kernels_seen or {} def map_size_param(self, expr: SizeParam) -> Array: - name = expr.name - assert name is not None - return SizeParam( # pylint: disable=missing-kwoa - name=name, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + assert expr.name is not None + return expr def map_placeholder(self, expr: Placeholder) -> Array: - name = expr.name - if name is None: - name = self.var_name_gen("_pt_in") - return Placeholder(name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s - for s in expr.shape), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_name = expr.name + if new_name is None: + new_name = self.var_name_gen("_pt_in") + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if ( + new_name is expr.name + and new_shape is expr.shape): + return expr + else: + return Placeholder(name=new_name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: from pytato.target.loopy import LoopyTarget if not isinstance(self.target, LoopyTarget): raise ValueError("Got a LoopyCall for a non-loopy target.") - translation_unit = expr.translation_unit.copy( - target=self.target.get_loopy_target()) + new_target = self.target.get_loopy_target() + + # FIXME: Can't use "is" here because targets aren't unique. Is it OK to + # use the existing target if it's equal to self.target.get_loopy_target()? + # If not, may have to set err_on_created_duplicate=False + if new_target == expr.translation_unit.target: + new_translation_unit = expr.translation_unit + else: + new_translation_unit = expr.translation_unit.copy(target=new_target) namegen = UniqueNameGenerator(set(self.kernels_seen)) - entrypoint = expr.entrypoint + new_entrypoint = expr.entrypoint # {{{ eliminate callable name collision - for name, clbl in translation_unit.callables_table.items(): + for name, clbl in new_translation_unit.callables_table.items(): if isinstance(clbl, lp.CallableKernel): assert isinstance(name, str) if name in self.kernels_seen and ( - translation_unit[name] != self.kernels_seen[name]): + new_translation_unit[name] != self.kernels_seen[name]): # callee name collision => must rename # {{{ see if it's one of the other kernels for other_knl in self.kernels_seen.values(): - if other_knl.copy(name=name) == translation_unit[name]: + if other_knl.copy(name=name) == new_translation_unit[name]: new_name = other_knl.name break else: @@ -198,37 +216,55 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: # }}} - if name == entrypoint: + if name == new_entrypoint: # if the colliding name is the entrypoint, then rename the # entrypoint as well. 
- entrypoint = new_name + new_entrypoint = new_name - translation_unit = lp.rename_callable( - translation_unit, name, new_name) + new_translation_unit = lp.rename_callable( + new_translation_unit, name, new_name) name = new_name self.kernels_seen[name] = clbl.subkernel # }}} - bindings: Mapping[str, Any] = immutabledict( + new_bindings: Mapping[str, Any] = immutabledict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - return LoopyCall(translation_unit=translation_unit, - bindings=bindings, - entrypoint=entrypoint, - tags=expr.tags - ) + assert ( + new_entrypoint is expr.entrypoint + or new_entrypoint != expr.entrypoint) + for bnd, new_bnd in zip( + expr.bindings.values(), new_bindings.values(), strict=True): + assert new_bnd is bnd or new_bnd != bnd + + if ( + new_translation_unit == expr.translation_unit + and ( + frozenset(new_bindings.keys()) + == frozenset(expr.bindings.keys())) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings) + and new_entrypoint is expr.entrypoint): + return expr + else: + return LoopyCall(translation_unit=new_translation_unit, + bindings=new_bindings, + entrypoint=new_entrypoint, + tags=expr.tags + ) def map_data_wrapper(self, expr: DataWrapper) -> Array: name = _generate_name_for_temp(expr, self.var_name_gen, "_pt_data") + shape = self.rec_idx_or_size_tuple(expr.shape) self.bound_arguments[name] = expr.data return Placeholder(name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s - for s in expr.shape), + shape=shape, dtype=expr.dtype, axes=expr.axes, tags=expr.tags, diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 741e36548..98f8a9027 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -240,9 +240,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None, + TransformMapperCache[FunctionDefinition, []] | None = None, ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -261,7 +261,7 @@ def clone_for_callee( return type(self)( {}, {}, {}, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) @@ -294,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: return new_send def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(inputs) except KeyError: pass @@ -329,6 +329,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. 
""" + # recv_ids: immutabledict[CommunicationOpIdentifier, None] + # send_ids: immutabledict[CommunicationOpIdentifier, None] recv_ids: FrozenOrderedSet[CommunicationOpIdentifier] send_ids: FrozenOrderedSet[CommunicationOpIdentifier] @@ -489,6 +491,9 @@ def map_named_call_result(self, expr: NamedCallResult) \ # {{{ _schedule_task_batches (and related) def _schedule_task_batches( + # Production + # task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ + # -> Sequence[dict[TaskType, None]]: task_ids_to_needed_task_ids: Mapping[TaskType, Set[TaskType]]) \ -> Sequence[OrderedSet[TaskType]]: """For each :type:`TaskType`, determine the @@ -504,6 +509,9 @@ def _schedule_task_batches( # {{{ _schedule_task_batches_counted def _schedule_task_batches_counted( + # Production + # task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ + # -> tuple[Sequence[dict[TaskType, None]], int]: task_ids_to_needed_task_ids: Mapping[TaskType, Set[TaskType]]) \ -> tuple[Sequence[OrderedSet[TaskType]], int]: """ @@ -518,7 +526,9 @@ def _schedule_task_batches_counted( [OrderedSet() for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - task_batches[dep_level].add(task_id) + if task_id not in task_batches[dep_level]: + # task_batches[dep_level][task_id] = None + task_batches[dep_level].add(task_id) return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -591,14 +601,17 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): + # self.materialized_arrays[expr] = None self.materialized_arrays.add(expr) if isinstance(expr, LoopyCallResult): + # self.materialized_arrays[expr] = None self.materialized_arrays.add(expr) from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): + # self.materialized_arrays[subexpr] = None self.materialized_arrays.add(subexpr) else: assert isinstance(subexpr, SCALAR_CLASSES) @@ -742,6 +755,7 @@ def find_distributed_partition( assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. 
""" import mpi4py.MPI as MPI + from immutabledict import immutabledict from pytato.transform import SubsetDependencyMapper @@ -783,7 +797,10 @@ def find_distributed_partition( comm_batches_or_exc = mpi_communicator.bcast(None) if isinstance(comm_batches_or_exc, Exception): raise comm_batches_or_exc - + # Production + # comm_batches = cast( + # "Sequence[Set[CommunicationOpIdentifier]]", + # comm_batches_or_exc) comm_batches = comm_batches_or_exc # }}} @@ -791,6 +808,7 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids part_comm_ids: list[_PartCommIDs] = [] + if comm_batches: recv_ids: FrozenOrderedSet[CommunicationOpIdentifier] = FrozenOrderedSet() for batch in comm_batches: @@ -803,7 +821,7 @@ def find_distributed_partition( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = FrozenOrderedSet( + recv_ids = immutabledict.fromkeys( comm_id for comm_id in batch if comm_id.dest_rank == local_rank) if recv_ids: @@ -811,7 +829,8 @@ def find_distributed_partition( _PartCommIDs( recv_ids=recv_ids, send_ids=FrozenOrderedSet())) - else: + + if not part_comm_ids: part_comm_ids.append( _PartCommIDs( recv_ids=FrozenOrderedSet(), @@ -842,7 +861,7 @@ def find_distributed_partition( materialized_arrays_collector = _MaterializedArrayCollector() materialized_arrays_collector(outputs) - # The sets of arrays below must have a deterministic order in order to ensure + # The collections of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic sent_arrays = FrozenOrderedSet( @@ -856,13 +875,12 @@ def find_distributed_partition( # We could allow sent *arrays* to be included here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. - materialized_arrays = ( - materialized_arrays_collector.materialized_arrays - - received_arrays - - sent_arrays) + materialized_arrays = {a: None + for a in materialized_arrays_collector.materialized_arrays + if a not in received_arrays | sent_arrays} # "mso" for "materialized/sent/output" - output_arrays = FrozenOrderedSet(outputs._data.values()) + output_arrays = dict.fromkeys(outputs._data.values()) mso_arrays = materialized_arrays | sent_arrays | output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to @@ -939,9 +957,12 @@ def get_materialized_predecessors(ary: Array) -> OrderedSet[Array]: for pred in direct_preds_getter(ary): assert isinstance(pred, Array) if pred in materialized_arrays: + # materialized_preds[pred] = None materialized_preds.add(pred) else: - materialized_preds |= get_materialized_predecessors(pred) + for p in get_materialized_predecessors(pred): + # materialized_preds[p] = None + materialized_preds.add(p) return materialized_preds stored_arrays_promoted_to_part_outputs = FrozenOrderedSet( @@ -950,7 +971,7 @@ def get_materialized_predecessors(ary: Array) -> OrderedSet[Array]: for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - ) + ) # }}} diff --git a/pytato/equality.py b/pytato/equality.py index 47bf7a0dc..4b95df358 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -47,7 +47,8 @@ SizeParam, Stack, ) -from pytato.function import FunctionDefinition +from pytato.function import Call, FunctionDefinition, NamedCallResult +from pytato.tags import Tag if TYPE_CHECKING: @@ -59,6 +60,7 @@ __doc__ = """ .. autoclass:: EqualityComparer +.. 
autoclass:: SimilarityComparer """ @@ -328,4 +330,293 @@ def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: # }}} + +# {{{ SimilarityComparer + +class SimilarityComparer: + """ + A :class:`pytato.array.Array` visitor to check structural similarity between two + expression DAGs. Data and array shapes are allowed to be different. + + .. note:: + + - Compares two expression graphs ``expr1``, ``expr2`` in :math:`O(N)` + comparisons, where :math:`N` is the number of nodes in ``expr1``. + - This visitor was introduced to memoize the sub-expression comparisons + of the expressions to be compared. Not memoizing the sub-expression + comparisons results in :math:`O(2^N)` complexity for the comparison + operation, where :math:`N` is the number of nodes in expressions. See + `GH-Issue-163 ` for + more on this. + """ + def __init__( + self, + # FIXME: tuple? + ignore_tag_types: frozenset(type) | None = None, + err_on_not_similar: bool = False) -> None: + # Uses the same cache for both arrays and functions + self._cache: dict[tuple[int, int], bool] = {} + if ignore_tag_types is None: + ignore_tag_types: frozenset(type) = frozenset() + self.ignore_tag_types = tuple(ignore_tag_types) + self.err_on_not_similar = err_on_not_similar + + def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool: + cache_key = id(expr1), id(expr2) + try: + return self._cache[cache_key] + except KeyError: + method: Callable[ + [Array | AbstractResultWithNamedArrays | FunctionDefinition, Any], + bool] + + try: + method = ( + getattr(self, expr1._mapper_method) + if isinstance(expr1, (Array, AbstractResultWithNamedArrays)) + else self.map_function_definition) + except AttributeError: + if isinstance(expr1, Array): + result = self.handle_unsupported_array(expr1, expr2) + else: + result = self.map_foreign(expr1, expr2) + else: + result = (expr1 is expr2) or method(expr1, expr2) + + if self.err_on_not_similar and not result: + raise ValueError(f"Not similar, {type(expr1).__name__}, {type(expr2).__name__}") + + self._cache[cache_key] = result + return result + + def __call__(self, expr1: ArrayOrNames, expr2: Any + ) -> bool: + return self.rec(expr1, expr2) + + def handle_unsupported_array(self, expr1: Array, + expr2: Any) -> bool: + raise NotImplementedError(type(expr1).__name__) + + def map_foreign(self, expr1: Any, expr2: Any) -> bool: + raise NotImplementedError(type(expr1).__name__) + + def _map_tags(self, tags1: frozenset(Tag), tags2: frozenset(Tag)) -> bool: + filtered_tags1 = frozenset( + tag for tag in tags1 if not isinstance(tag, self.ignore_tag_types)) + filtered_tags2 = frozenset( + tag for tag in tags2 if not isinstance(tag, self.ignore_tag_types)) + return filtered_tags1 == filtered_tags2 + + def map_placeholder(self, expr1: Placeholder, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.name == expr2.name + and len(expr1.shape) == len(expr2.shape) + and expr1.dtype == expr2.dtype + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.name == expr2.name + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_data_wrapper(self, expr1: DataWrapper, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.data.__class__ is expr2.data.__class__ + and expr1.name == expr2.name + and len(expr1.shape) == len(expr2.shape) + and all(self.rec(dim1, dim2) + 
for dim1, dim2 in zip(expr1.shape, expr2.shape) + if isinstance(dim1, Array)) + and expr1.dtype == expr2.dtype + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.expr == expr2.expr + and (frozenset(expr1.bindings.keys()) + == frozenset(expr2.bindings.keys())) + and all(self.rec(expr1.bindings[name], expr2.bindings[name]) + for name in expr1.bindings) + and len(expr1.shape) == len(expr2.shape) + and all(self.rec(dim1, dim2) + for dim1, dim2 in zip(expr1.shape, expr2.shape) + if isinstance(dim1, Array)) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + and expr1.var_to_reduction_descr == expr2.var_to_reduction_descr + ) + + def map_stack(self, expr1: Stack, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.axis == expr2.axis + and len(expr1.arrays) == len(expr2.arrays) + and all(self.rec(ary1, ary2) + for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_concatenate(self, expr1: Concatenate, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.axis == expr2.axis + and len(expr1.arrays) == len(expr2.arrays) + and all(self.rec(ary1, ary2) + for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_roll(self, expr1: Roll, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.axis == expr2.axis + and expr1.shift == expr2.shift + and self.rec(expr1.array, expr2.array) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_axis_permutation(self, expr1: AxisPermutation, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.axis_permutation == expr2.axis_permutation + and self.rec(expr1.array, expr2.array) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def _map_index_base(self, expr1: IndexBase, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1.array, expr2.array) + and len(expr1.indices) == len(expr2.indices) + and all(self.rec(idx1, idx2) + if (isinstance(idx1, Array) + and isinstance(idx2, Array)) + else idx1 == idx2 + for idx1, idx2 in zip(expr1.indices, expr2.indices)) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_basic_index(self, expr1: BasicIndex, expr2: Any) -> bool: + return self._map_index_base(expr1, expr2) + + def map_contiguous_advanced_index(self, + expr1: AdvancedIndexInContiguousAxes, + expr2: Any + ) -> bool: + return self._map_index_base(expr1, expr2) + + def map_non_contiguous_advanced_index(self, + expr1: AdvancedIndexInNoncontiguousAxes, + expr2: Any + ) -> bool: + return self._map_index_base(expr1, expr2) + + def map_reshape(self, expr1: Reshape, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and len(expr1.newshape) == len(expr2.newshape) + and self.rec(expr1.array, expr2.array) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.access_descriptors == expr2.access_descriptors + and all(self.rec(ary1, ary2) + for ary1, ary2 in zip(expr1.args, + expr2.args)) + and self._map_tags(expr1.tags, expr2.tags) + and 
expr1.axes == expr2.axes + and expr1.redn_axis_to_redn_descr == expr2.redn_axis_to_redn_descr + ) + + def map_named_array(self, expr1: NamedArray, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1._container, expr2._container) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + and expr1.name == expr2.name) + + def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.translation_unit == expr2.translation_unit + and expr1.entrypoint == expr2.entrypoint + and frozenset(expr1.bindings) == frozenset(expr2.bindings) + and all(self.rec(bnd, + expr2.bindings[name]) + if isinstance(bnd, Array) + else bnd == expr2.bindings[name] + for name, bnd in expr1.bindings.items()) + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1._container, expr2._container) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + and expr1.name == expr2.name) + + def map_dict_of_named_arrays(self, expr1: DictOfNamedArrays, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and frozenset(expr1._data.keys()) == frozenset(expr2._data.keys()) + and all(self.rec(expr1._data[name], expr2._data[name]) + for name in expr1._data) + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_distributed_send_ref_holder( + self, expr1: DistributedSendRefHolder, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1.send.data, expr2.send.data) + and self.rec(expr1.passthrough_data, expr2.passthrough_data) + and expr1.send.dest_rank == expr2.send.dest_rank + and expr1.send.comm_tag == expr2.send.comm_tag + and expr1.send.tags == expr2.send.tags + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.src_rank == expr2.src_rank + and expr1.comm_tag == expr2.comm_tag + and len(expr1.shape) == len(expr2.shape) + and expr1.dtype == expr2.dtype + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_function_definition(self, expr1: FunctionDefinition, expr2: Any + ) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.parameters == expr2.parameters + and expr1.return_type == expr2.return_type + and (set(expr1.returns.keys()) == set(expr2.returns.keys())) + and all(self.rec(expr1.returns[k], expr2.returns[k]) + for k in expr1.returns) + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_call(self, expr1: Call, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1.function, expr2.function) + and frozenset(expr1.bindings) == frozenset(expr2.bindings) + and all(self.rec(bnd, + expr2.bindings[name]) + for name, bnd in expr1.bindings.items()) + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.name == expr2.name + and self.rec(expr1._container, expr2._container)) + +# }}} + # vim: fdm=marker diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 22ff78aad..93b77fdb5 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -50,6 +50,7 @@ TYPE_CHECKING, Any, Never, + Union, cast, ) @@ -84,6 +85,9 @@ # {{{ scalar expressions INT_CLASSES = (int, np.integer) +# 
IntegralScalarExpression = Union[IntegerT, prim.Expression] +Scalar = Union[np.number[Any], int, np.bool_, bool, float, complex] +ScalarExpression = Union[Scalar, prim.Expression] PYTHON_SCALAR_CLASSES = (int, float, complex, bool) SCALAR_CLASSES = prim.VALID_CONSTANT_CLASSES diff --git a/pytato/tags.py b/pytato/tags.py index e0a98b7da..e798a7c8f 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -17,6 +17,9 @@ .. autoclass:: FunctionIdentifier .. autoclass:: CallImplementationTag .. autoclass:: InlineCallTag +.. autoclass:: UseInputAxis +.. autoclass:: ConcatenatedCallInputConcatAxisTag +.. autoclass:: ConcatenatedCallOutputSliceAxisTag """ from dataclasses import dataclass @@ -233,3 +236,30 @@ class InlineCallTag(CallImplementationTag): A :class:`CallImplementationTag` that directs the :class:`pytato.target.Target` to inline the call site. """ + + +@dataclass(frozen=True) +class UseInputAxis(UniqueTag): + """ + A placeholder axis tag indicating that an array should derive tags from one of + its inputs. + """ + key: Hashable + axis: int + + +@dataclass(frozen=True) +class ConcatenatedCallInputConcatAxisTag(UniqueTag): + """ + An axis tag indicating that an array is a concatenation of multiple + inputs resulting from the transformations done in + :func:`pytato.concatenate_calls`. + """ + + +@dataclass(frozen=True) +class ConcatenatedCallOutputSliceAxisTag(UniqueTag): + """ + An axis tag indicating that an array is a slice of a concatenated output + resulting from the transformations done in :func:`pytato.concatenate_calls`. + """ diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index dc1045f25..9a1e1437b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -43,6 +43,7 @@ import numpy as np from immutabledict import immutabledict +from orderedsets import FrozenOrderedSet, OrderedSet from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper @@ -82,7 +83,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Mapping + from collections.abc import Callable, Iterable, Mapping, Sequence ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays @@ -93,6 +94,7 @@ __doc__ = """ .. autoclass:: Mapper +.. autoclass:: CacheInputsWithKey .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapperCache @@ -100,6 +102,7 @@ .. autoclass:: TransformMapperWithExtraArgs .. autoclass:: CopyMapper .. autoclass:: CopyMapperWithExtraArgs +.. autoclass:: Deduplicator .. autoclass:: CombineMapper .. autoclass:: DependencyMapper .. autoclass:: InputGatherer @@ -114,6 +117,7 @@ .. autofunction:: map_and_copy .. autofunction:: materialize_with_mpms .. autofunction:: deduplicate_data_wrappers +.. autofunction:: unify_materialization_tags .. automodule:: pytato.transform.lower_to_index_lambda .. automodule:: pytato.transform.remove_broadcasts_einsum .. 
automodule:: pytato.transform.einsum_distributive_law @@ -187,6 +191,14 @@ class ForeignObjectError(ValueError): pass +class CacheCollisionError(ValueError): + pass + + +class MapperCreatedDuplicateError(ValueError): + pass + + # {{{ mapper base class ResultT = TypeVar("ResultT") @@ -299,17 +311,50 @@ def __call__( # {{{ CachedMapper -CacheExprT = TypeVar("CacheExprT") +CacheExprT = TypeVar("CacheExprT", bound=ArrayOrNames | FunctionDefinition) CacheResultT = TypeVar("CacheResultT") CacheKeyT: TypeAlias = Hashable -class CachedMapperCache(Generic[CacheExprT, CacheResultT]): +class CacheInputsWithKey(Generic[CacheExprT, P]): + """ + Data structure for inputs to :class:`CachedMapperCache`. + + .. attribute:: expr + + The input expression being mapped. + + .. attribute:: args + + A :class:`tuple` of extra positional arguments. + + .. attribute:: kwargs + + A :class:`dict` of extra keyword arguments. + + .. attribute:: key + + The cache key corresponding to *expr* and any additional inputs that were + passed. + + """ + def __init__( + self, + expr: CacheExprT, + key: CacheKeyT, + *args: P.args, + **kwargs: P.kwargs): + self.expr: CacheExprT = expr + self.args: tuple[Any, ...] = args + self.kwargs: dict[str, Any] = kwargs + self.key: CacheKeyT = key + + +class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): """ Cache for mappers. .. automethod:: __init__ - .. method:: get_key Compute the key for an input expression. @@ -317,64 +362,51 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT]): .. automethod:: retrieve .. automethod:: clear """ - def __init__( - self, - key_func: Callable[..., CacheKeyT]) -> None: + def __init__(self, err_on_collision: bool) -> None: """ Initialize the cache. - :arg key_func: Function to compute a hashable cache key from an input - expression and any extra arguments. + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. """ - self.get_key = key_func + self.err_on_collision = err_on_collision - self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} + if self.err_on_collision: + self._input_key_to_expr: dict[CacheKeyT, CacheExprT] = {} def add( self, - key_inputs: - CacheExprT - # Currently, Python's type system doesn't have a way to annotate - # containers of args/kwargs (ParamSpec won't work here). So we have - # to fall back to using Any. More details here: - # https://github.com/python/typing/issues/1252 - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - result: CacheResultT, - key: CacheKeyT | None = None) -> CacheResultT: + inputs: CacheInputsWithKey[CacheExprT, P], + result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) + key = inputs.key - assert key not in self._expr_key_to_result, \ + assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." 
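A small illustrative sketch of how the new cache pieces fit together (not part of this change; ``expr`` and ``compute_result`` are hypothetical, and the names ``CachedMapperCache``/``CacheInputsWithKey`` are the ones introduced in this hunk): ``retrieve`` raises ``KeyError`` on a miss, and ``add`` stores and returns the result::

    cache = CachedMapperCache(err_on_collision=True)

    # by default the cache key is the expression itself
    inputs = CacheInputsWithKey(expr, key=expr)

    try:
        result = cache.retrieve(inputs)      # KeyError on a cache miss
    except KeyError:
        result = cache.add(inputs, compute_result(expr))
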
- self._expr_key_to_result[key] = result + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr return result - def retrieve( - self, - key_inputs: - CacheExprT - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - key: CacheKeyT | None = None) -> CacheResultT: + def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) + key = inputs.key + + result = self._input_key_to_result[key] - return self._expr_key_to_result[key] + if self.err_on_collision and inputs.expr is not self._input_key_to_expr[key]: + raise CacheCollisionError + + return result def clear(self) -> None: """Reset the cache.""" - self._expr_key_to_result = {} + self._input_key_to_result = {} + if self.err_on_collision: + self._input_key_to_expr = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): @@ -388,59 +420,110 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """ def __init__( self, + err_on_collision: bool = False, _cache: - CachedMapperCache[ArrayOrNames, ResultT] | None = None, + CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: - CachedMapperCache[FunctionDefinition, FunctionResultT] | None = None + CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None ) -> None: super().__init__() - self._cache: CachedMapperCache[ArrayOrNames, ResultT] = ( + self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache(self.get_cache_key)) + else CachedMapperCache(err_on_collision=err_on_collision)) self._function_cache: CachedMapperCache[ - FunctionDefinition, FunctionResultT] = ( + FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache(self.get_function_definition_cache_key)) + else CachedMapperCache(err_on_collision=err_on_collision)) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + # Depending on whether extra arguments are passed by position or by + # keyword, they can end up in either args or kwargs; hence key is not + # uniquely defined in the general case + raise NotImplementedError( + "Derived classes must override get_cache_key if using extra inputs.") + return expr def get_function_definition_cache_key( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + # Depending on whether extra arguments are passed by position or by + # keyword, they can end up in either args or kwargs; hence key is not + # uniquely defined in the general case + raise NotImplementedError( + "Derived classes must override get_function_definition_cache_key if " + "using extra inputs.") + return expr + + def _make_cache_inputs( + self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputsWithKey[ArrayOrNames, P]: + return CacheInputsWithKey( + expr, self.get_cache_key(expr, *args, **kwargs), *args, **kwargs) + + def _make_function_definition_cache_inputs( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputsWithKey[FunctionDefinition, P]: + return 
CacheInputsWithKey( + expr, self.get_function_definition_cache_key(expr, *args, **kwargs), + *args, **kwargs) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ResultT) -> ResultT: + return self._cache.add(inputs, result) + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionResultT) -> FunctionResultT: + return self._function_cache.add(inputs, result) + + def _cache_retrieve(self, inputs: CacheInputsWithKey[ArrayOrNames, P]) -> ResultT: + try: + return self._cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_retrieve( + self, inputs: CacheInputsWithKey[FunctionDefinition, P]) -> FunctionResultT: + try: + return self._function_cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: - key = self._cache.get_key(expr, *args, **kwargs) + inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve((expr, args, kwargs), key=key) + return self._cache_retrieve(inputs) except KeyError: - return self._cache.add( - (expr, args, kwargs), - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - Mapper.rec(self, expr, *args, **kwargs), - key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: - key = self._function_cache.get_key(expr, *args, **kwargs) + inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve((expr, args, kwargs), key=key) + return self._function_cache_retrieve(inputs) except KeyError: - return self._function_cache.add( - (expr, args, kwargs), + return self._function_cache_add( # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - Mapper.rec_function_definition(self, expr, *args, **kwargs), - key=key) + inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs)) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -448,16 +531,96 @@ def clone_for_callee( Called to clone *self* before starting traversal of a :class:`pytato.function.FunctionDefinition`. 
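In user code, the collision check is opt-in through the new ``err_on_collision`` flag; a minimal sketch, assuming ``dag`` is some existing pytato expression::

    import pytato as pt

    # raises ValueError if two distinct (but equal) array instances are
    # encountered while traversing dag; otherwise returns a copy of dag
    copied = pt.transform.CopyMapper(err_on_collision=True)(dag)
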
""" - # Functions are cached globally, but arrays aren't - return type(self)(_function_cache=self._function_cache) + return type(self)( + err_on_collision=self._cache.err_on_collision, + # Functions are cached globally, but arrays aren't + _function_cache=self._function_cache) # }}} # {{{ TransformMapper -class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT]): - pass +def _is_mapper_created_duplicate(expr: CacheExprT, result: CacheExprT) -> bool: + """Returns *True* if *result* is not identical to *expr* when it ought to be.""" + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + return ( + hash(result) == hash(expr) + and result == expr + and result is not expr + # Only consider "direct" duplication, not duplication resulting from + # equality-preserving changes to predecessors. Assume that such changes are + # OK, otherwise they would have been detected at the point at which they + # originated. (For example, consider a DAG containing pre-existing + # duplicates. If a subexpression of *expr* is a duplicate and is replaced + # with a previously encountered version from the cache, a new instance of + # *expr* must be created. This should not trigger an error.) + and all( + result_pred is pred + for pred, result_pred in zip( + # type-ignore-reason: mypy doesn't seem to recognize overloaded + # Mapper.__call__ here + pred_getter(expr), # type: ignore[arg-type] + pred_getter(result), # type: ignore[arg-type] + strict=True))) + + +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): + """ + Cache for :class:`TransformMapper` and :class:`TransformMapperWithExtraArgs`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_created_duplicate: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_created_duplicate = err_on_created_duplicate + + self._result_to_cached_result: dict[CacheExprT, CacheExprT] = {} + + def add( + self, + inputs: CacheInputsWithKey[CacheExprT, P], + result: CacheExprT) -> CacheExprT: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + try: + result = self._result_to_cached_result[result] + except KeyError: + if ( + self.err_on_created_duplicate + and _is_mapper_created_duplicate(inputs.expr, result)): + raise MapperCreatedDuplicateError from None + + self._result_to_cached_result[result] = result + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): @@ -467,13 +630,71 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): Enables certain operations that can only be done if the mapping results are also arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. + + .. automethod:: __init__ + .. 
automethod:: clone_for_callee """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + err_on_collision: bool = False, + err_on_created_duplicate: bool = False, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, []], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_created_duplicate=function_cache.err_on_created_duplicate, + _function_cache=function_cache) # }}} @@ -489,14 +710,72 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + err_on_collision: bool = False, + err_on_created_duplicate: bool = False, + _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None + TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. 
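Because the default ``get_cache_key`` above refuses extra positional or keyword arguments, a subclass that threads extra state must supply its own key. A rough sketch under that assumption; ``DepthThreader``, its ``depth`` argument, and ``dag`` are hypothetical::

    import pytato as pt

    class DepthThreader(pt.transform.CopyMapperWithExtraArgs):
        # the extra positional argument is the current recursion depth
        def get_cache_key(self, expr, depth):
            return (expr, depth)

        def map_roll(self, expr, depth):
            # pass an updated depth on to the subexpressions
            return super().map_roll(expr, depth + 1)

    result = DepthThreader()(dag, 0)
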
+ """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"mapper-created duplicate detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, P]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_created_duplicate=function_cache.err_on_created_duplicate, + _function_cache=function_cache) # }}} @@ -516,63 +795,104 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] ) -> tuple[IndexOrShapeExpr, ...]: # type-ignore-reason: apparently mypy cannot substitute typevars # here. - return tuple(self.rec(s) if isinstance(s, Array) else s # type: ignore[misc] - for s in situp) + new_situp = tuple( + self.rec(s) if isinstance(s, Array) else s + for s in situp) + if all(new_s is s for s, new_s in zip(situp, new_situp, strict=True)): + return situp + else: + return new_situp # type: ignore[return-value] def map_index_lambda(self, expr: IndexLambda) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + new_shape = self.rec_idx_or_size_tuple(expr.shape) + new_bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr) for name, subexpr in sorted(expr.bindings.items())}) - return IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, - bindings=bindings, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder) -> Array: assert expr.name is not None - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return Placeholder(name=expr.name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + 
tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> Array: - arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate) -> Array: - arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) - return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Concatenate(arrays=new_arrays, axis=expr.axis, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll) -> Array: - return Roll(array=_verify_is_array(self.rec(expr.array)), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + if new_ary is expr.array: + return expr + else: + return Roll(array=new_ary, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: - return AxisPermutation(array=_verify_is_array(self.rec(expr.array)), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + if new_ary is expr.array: + return expr + else: + return AxisPermutation(array=new_ary, + axis_permutation=expr.axis_permutation, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: - return type(expr)(_verify_is_array(self.rec(expr.array)), - indices=self.rec_idx_or_size_tuple(expr.indices), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + new_indices = self.rec_idx_or_size_tuple(expr.indices) + if new_ary is expr.array and new_indices is expr.indices: + return expr + else: + return type(expr)(new_ary, + indices=new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex) -> Array: return self._map_index_base(expr) @@ -588,91 +908,131 @@ def map_non_contiguous_advanced_index(self, return self._map_index_base(expr) def map_data_wrapper(self, expr: DataWrapper) -> Array: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr.shape), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return DataWrapper( + data=expr.data, + shape=new_shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam) -> Array: assert expr.name is not None - return SizeParam( - name=expr.name, - axes=expr.axes, - 
tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + return expr def map_einsum(self, expr: Einsum) -> Array: - return Einsum(expr.access_descriptors, - tuple(_verify_is_array(self.rec(arg)) for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_named_array(self, expr: NamedArray) -> Array: - container = self.rec(expr._container) - assert isinstance(container, AbstractResultWithNamedArrays) - return type(expr)(container, - expr.name, + new_args = tuple(_verify_is_array(self.rec(arg)) for arg in expr.args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + return expr + else: + return Einsum(expr.access_descriptors, + new_args, axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_named_array(self, expr: NamedArray) -> Array: + new_container = self.rec(expr._container) + assert isinstance(new_container, AbstractResultWithNamedArrays) + if new_container is expr._container: + return expr + else: + return type(expr)(new_container, + expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> DictOfNamedArrays: - return DictOfNamedArrays({key: _verify_is_array(self.rec(val.expr)) - for key, val in expr.items()}, - tags=expr.tags - ) + new_data = { + key: _verify_is_array(self.rec(val.expr)) + for key, val in expr.items()} + if all( + new_data_val is val.expr + for val, new_data_val in zip( + expr.values(), + new_data.values(), + strict=True)): + return expr + else: + return DictOfNamedArrays(new_data, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + new_bindings: Mapping[Any, Any] = immutabledict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - - return LoopyCall(translation_unit=expr.translation_unit, - bindings=bindings, - entrypoint=expr.entrypoint, - tags=expr.tags, - ) + if ( + frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return LoopyCall(translation_unit=expr.translation_unit, + bindings=new_bindings, + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: - rec_container = self.rec(expr._container) - assert isinstance(rec_container, LoopyCall) - return LoopyCallResult( - _container=rec_container, - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_container = self.rec(expr._container) + assert isinstance(new_container, LoopyCall) + if new_container is expr._container: + return expr + else: + return LoopyCallResult( + _container=new_container, + name=expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: - return Reshape(_verify_is_array(self.rec(expr.array)), - newshape=self.rec_idx_or_size_tuple(expr.newshape), - order=expr.order, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + new_newshape = self.rec_idx_or_size_tuple(expr.newshape) + if new_ary is expr.array and new_newshape is expr.newshape: + 
return expr + else: + return Reshape(new_ary, + newshape=new_newshape, + order=expr.order, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> Array: - return DistributedSendRefHolder( - send=DistributedSend( - data=_verify_is_array(self.rec(expr.send.data)), - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag), - passthrough_data=_verify_is_array(self.rec(expr.passthrough_data)), - ) + new_send_data = _verify_is_array(self.rec(expr.send.data)) + if new_send_data is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + data=new_send_data, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag) + new_passthrough = _verify_is_array(self.rec(expr.passthrough_data)) + if new_send is expr.send and new_passthrough is expr.passthrough_data: + return expr + else: + return DistributedSendRefHolder(new_send, new_passthrough) def map_distributed_recv(self, expr: DistributedRecv) -> Array: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=new_shape, dtype=expr.dtype, tags=expr.tags, + axes=expr.axes, non_equality_tags=expr.non_equality_tags) def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinition: @@ -681,19 +1041,37 @@ def map_function_definition(self, new_mapper = self.clone_for_callee(expr) new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} - return dataclasses.replace(expr, returns=immutabledict(new_returns)) + if all( + new_ret is ret + for ret, new_ret in zip( + expr.returns.values(), + new_returns.values(), + strict=True)): + return expr + else: + return dataclasses.replace(expr, returns=immutabledict(new_returns)) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return Call(self.rec_function_definition(expr.function), - immutabledict({name: self.rec(bnd) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, - ) + new_function = self.rec_function_definition(expr.function) + new_bindings = { + name: _verify_is_array(self.rec(bnd)) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) def map_named_call_result(self, expr: NamedCallResult) -> Array: - call = self.rec(expr._container) - assert isinstance(call, Call) - return call[expr.name] + new_call = self.rec(expr._container) + assert isinstance(new_call, Call) + return new_call[expr.name] class CopyMapperWithExtraArgs(TransformMapperWithExtraArgs[P]): @@ -717,70 +1095,102 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...], def map_index_lambda(self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + new_bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr, *args, **kwargs) for name, subexpr in sorted(expr.bindings.items())}) - return 
IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr.shape, - *args, **kwargs), - dtype=expr.dtype, - bindings=bindings, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder, *args: P.args, **kwargs: P.kwargs) -> Array: assert expr.name is not None - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr.shape, - *args, **kwargs), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return Placeholder(name=expr.name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple( + new_arrays: tuple[Array, ...] = tuple( _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple( + new_arrays: tuple[Array, ...] 
= tuple( _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) - return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Concatenate(arrays=new_arrays, axis=expr.axis, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array: - return Roll(array=_verify_is_array(self.rec(expr.array, *args, **kwargs)), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + if new_ary is expr.array: + return expr + else: + return Roll(array=new_ary, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> Array: - return AxisPermutation(array=_verify_is_array( - self.rec(expr.array, *args, **kwargs)), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + if new_ary is expr.array: + return expr + else: + return AxisPermutation(array=new_ary, + axis_permutation=expr.axis_permutation, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase, *args: P.args, **kwargs: P.kwargs) -> Array: assert isinstance(expr, _SuppliedAxesAndTagsMixin) - return type(expr)(_verify_is_array(self.rec(expr.array, *args, **kwargs)), - indices=self.rec_idx_or_size_tuple(expr.indices, - *args, **kwargs), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + new_indices = self.rec_idx_or_size_tuple(expr.indices, *args, **kwargs) + if new_ary is expr.array and new_indices is expr.indices: + return expr + else: + return type(expr)(new_ary, + indices=new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> Array: @@ -801,98 +1211,141 @@ def map_non_contiguous_advanced_index(self, def map_data_wrapper(self, expr: DataWrapper, *args: P.args, **kwargs: P.kwargs) -> Array: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return DataWrapper( + data=expr.data, + shape=new_shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam, *args: P.args, **kwargs: P.kwargs) -> Array: assert expr.name is not None - return SizeParam(name=expr.name, axes=expr.axes, tags=expr.tags) + return expr def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> Array: - return Einsum(expr.access_descriptors, - tuple(_verify_is_array( - self.rec(arg, *args, **kwargs)) for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def 
map_named_array(self, - expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: - container = self.rec(expr._container, *args, **kwargs) - assert isinstance(container, AbstractResultWithNamedArrays) - return type(expr)(container, - expr.name, + new_args: tuple[Array, ...] = tuple( + _verify_is_array(self.rec(arg, *args, **kwargs)) for arg in expr.args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + return expr + else: + return Einsum(expr.access_descriptors, + new_args, axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_named_array(self, + expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: + new_container = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_container, AbstractResultWithNamedArrays) + if new_container is expr._container: + return expr + else: + return type(expr)(new_container, + expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: P.args, **kwargs: P.kwargs ) -> DictOfNamedArrays: - return DictOfNamedArrays({key: _verify_is_array( - self.rec(val.expr, *args, **kwargs)) - for key, val in expr.items()}, - tags=expr.tags, - ) + new_data = { + key: _verify_is_array(self.rec(val.expr, *args, **kwargs)) + for key, val in expr.items()} + if all( + new_data_val is val.expr + for val, new_data_val in zip( + expr.values(), + new_data.values(), + strict=True)): + return expr + else: + return DictOfNamedArrays(new_data, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall, *args: P.args, **kwargs: P.kwargs) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + new_bindings: Mapping[Any, Any] = immutabledict( {name: (self.rec(subexpr, *args, **kwargs) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - - return LoopyCall(translation_unit=expr.translation_unit, - bindings=bindings, - entrypoint=expr.entrypoint, - tags=expr.tags, - ) + if ( + frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return LoopyCall(translation_unit=expr.translation_unit, + bindings=new_bindings, + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult, *args: P.args, **kwargs: P.kwargs) -> Array: - rec_loopy_call = self.rec(expr._container, *args, **kwargs) - assert isinstance(rec_loopy_call, LoopyCall) - return LoopyCallResult( - _container=rec_loopy_call, - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_container = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_container, LoopyCall) + if new_container is expr._container: + return expr + else: + return LoopyCallResult( + _container=new_container, + name=expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> Array: - return Reshape(_verify_is_array(self.rec(expr.array, *args, **kwargs)), - newshape=self.rec_idx_or_size_tuple(expr.newshape, - *args, **kwargs), - order=expr.order, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + new_newshape = 
self.rec_idx_or_size_tuple(expr.newshape, *args, **kwargs) + if new_ary is expr.array and new_newshape is expr.newshape: + return expr + else: + return Reshape(new_ary, + newshape=new_newshape, + order=expr.order, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, *args: P.args, **kwargs: P.kwargs) -> Array: - return DistributedSendRefHolder( - send=DistributedSend( - data=_verify_is_array(self.rec(expr.send.data, *args, **kwargs)), - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag), - passthrough_data=_verify_is_array( - self.rec(expr.passthrough_data, *args, **kwargs))) + new_send_data = _verify_is_array(self.rec(expr.send.data, *args, **kwargs)) + if new_send_data is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + data=new_send_data, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag) + new_passthrough = _verify_is_array( + self.rec(expr.passthrough_data, *args, **kwargs)) + if new_send is expr.send and new_passthrough is expr.passthrough_data: + return expr + else: + return DistributedSendRefHolder(new_send, new_passthrough) def map_distributed_recv(self, expr: DistributedRecv, *args: P.args, **kwargs: P.kwargs) -> Array: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=new_shape, dtype=expr.dtype, tags=expr.tags, + axes=expr.axes, non_equality_tags=expr.non_equality_tags) def map_function_definition( self, expr: FunctionDefinition, @@ -904,17 +1357,49 @@ def map_function_definition( def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> AbstractResultWithNamedArrays: - return Call(self.rec_function_definition(expr.function, *args, **kwargs), - immutabledict({name: self.rec(bnd, *args, **kwargs) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, - ) + new_function = self.rec_function_definition(expr.function, *args, **kwargs) + new_bindings = { + name: self.rec(bnd, *args, **kwargs) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) def map_named_call_result(self, expr: NamedCallResult, *args: P.args, **kwargs: P.kwargs) -> Array: - call = self.rec(expr._container, *args, **kwargs) - assert isinstance(call, Call) - return call[expr.name] + new_call = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_call, Call) + return new_call[expr.name] + +# }}} + + +# {{{ Deduplicator + +class Deduplicator(CopyMapper): + """Removes duplicate nodes from an expression.""" + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__( + err_on_collision=False, err_on_created_duplicate=False, + _cache=_cache, + _function_cache=_function_cache) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + 
_function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) # }}} @@ -953,6 +1438,9 @@ def map_placeholder(self, expr: Placeholder) -> ResultT: def map_data_wrapper(self, expr: DataWrapper) -> ResultT: return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + def map_size_param(self, expr: SizeParam) -> ResultT: + return self.combine(cast("ResultT", frozenset({}))) + def map_stack(self, expr: Stack) -> ResultT: return self.combine(*(self.rec(ary) for ary in expr.arrays)) @@ -1036,7 +1524,8 @@ def map_named_call_result(self, expr: NamedCallResult) -> ResultT: # {{{ DependencyMapper -class DependencyMapper(CombineMapper[R, R]): +# FIXME: Change to ordered sets (including R) +class DependencyMapper(CombineMapper[R, Never]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of :class:`pytato.array.Array`'s it depends on. @@ -1098,23 +1587,23 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) - def map_function_definition(self, expr: FunctionDefinition) -> R: + def map_call(self, expr: Call) -> R: # do not include arrays from the function's body as it would involve # putting arrays from different namespaces into the same collection. - return frozenset() - - def map_call(self, expr: Call) -> R: - return self.combine(self.rec_function_definition(expr.function), - *[self.rec(bnd) for bnd in expr.bindings.values()]) + return self.combine(*[self.rec(bnd) for bnd in expr.bindings.values()]) def map_named_call_result(self, expr: NamedCallResult) -> R: return self.rec(expr._container) + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + # }}} # {{{ SubsetDependencyMapper +# FIXME: Change to ordered sets class SubsetDependencyMapper(DependencyMapper): """ Mapper to combine the dependencies of an expression that are a subset of @@ -1135,6 +1624,7 @@ def combine(self, *args: frozenset[Array]) -> frozenset[Array]: # {{{ InputGatherer +# FIXME: Change to ordered sets class InputGatherer( CombineMapper[frozenset[InputArgumentBase], frozenset[InputArgumentBase]]): """ @@ -1184,8 +1674,140 @@ def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: # }}} +# {{{ precompute_subexpressions + +# FIXME: Change to ordered sets +# FIXME: Think about what happens when subexpressions contain outlined functions +class _PrecomputableSubexpressionGatherer( + CombineMapper[frozenset[Array], frozenset[Array]]): + """ + Mapper to find subexpressions that do not depend on any placeholders. + """ + def rec(self, expr: ArrayOrNames) -> frozenset[Array]: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: frozenset[Array] = Mapper.rec(self, expr) + if not isinstance(expr, + Placeholder + | DictOfNamedArrays + | Call): + assert isinstance(expr, Array) + from pytato.analysis import DirectPredecessorsGetter + if result == DirectPredecessorsGetter()(expr): + result = frozenset({expr}) + return self._cache_add(inputs, result) + + # type-ignore reason: incompatible ret. 
type with super class + def __call__(self, expr: ArrayOrNames) -> frozenset[Array]: # type: ignore + subexprs = self.rec(expr) + + # Need to treat data arrays as precomputable during recursion, but afterwards + # we only care about larger expressions containing them *or* their shape if + # it's a non-constant expression + # FIXME: Does it even make sense for a data array to have an expression as + # a shape? Maybe this isn't necessary... + + data_subexprs = { + ary + for ary in subexprs + if isinstance(ary, DataWrapper | DistributedRecv)} + + subexprs -= data_subexprs + + for ary in data_subexprs: + subexprs |= self.combine(*self.rec_idx_or_size_tuple(ary.shape)) + + return subexprs + + def combine(self, *args: frozenset[Array]) -> frozenset[Array]: + from functools import reduce + return reduce(lambda a, b: a | b, args, frozenset()) + + def map_function_definition(self, expr: FunctionDefinition) -> frozenset[Array]: + # FIXME: Ignoring subexpressions inside function definitions for now + return frozenset() + + def map_call(self, expr: Call) -> frozenset[Array]: + rec_fn = self.rec_function_definition(expr.function) + assert not rec_fn + rec_bindings: Mapping[str, frozenset[Array]] = immutabledict({ + name: self.rec(bnd) if isinstance(bnd, Array) else frozenset({bnd}) + for name, bnd in expr.bindings.items()}) + if all( + rec_bindings[name] == frozenset({expr.bindings[name]}) + for name in expr.bindings): + # FIXME: This conflicts with type annotations + return frozenset({expr}) + else: + return self.combine(rec_fn, *rec_bindings.values()) + + +class _PrecomputableSubexpressionReplacer(CopyMapper): + """ + Mapper to replace precomputable subexpressions found by + :class:`_PrecomputableSubexpressionGatherer` with the evaluated versions. + """ + def __init__( + self, + replacement_map: Mapping[Array, Array], + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: + TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.replacement_map = replacement_map + + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: ArrayOrNames | None = None + if isinstance(expr, Array): + try: + result = self.replacement_map[expr] + except KeyError: + pass + result = self.rec(result) if result is not None else Mapper.rec(self, expr) + return self._cache_add(inputs, result) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + {}, + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]",self._function_cache)) + + +def precompute_subexpressions( + expr: ArrayOrNames, + # FIXME: Don't use Sequence for this + eval_func: Callable[[Sequence[ArrayOrNames]], Sequence[ArrayOrNames]] + ) -> ArrayOrNames: + """Evaluate subexpressions in *expr* that do not depend on any placeholders.""" + precomputable_subexprs = _PrecomputableSubexpressionGatherer()(expr) + for subexpr in precomputable_subexprs: + from pytato.analysis import get_num_nodes + nnodes = get_num_nodes(subexpr) + if nnodes > 1: + print( + "Found precomputable subexpression of type " + f"{type(subexpr).__name__} with {nnodes} nodes.") + from pytools.obj_array import make_obj_array + # FIXME: Don't use object array + precomputable_subexprs_ary = make_obj_array(list(precomputable_subexprs)) + evaled_subexprs_ary = eval_func(precomputable_subexprs_ary) + subexpr_to_evaled_subexpr = dict( 
+ zip(precomputable_subexprs_ary, evaled_subexprs_ary, strict=True)) + return _PrecomputableSubexpressionReplacer(subexpr_to_evaled_subexpr)(expr) + +# }}} + + # {{{ SizeParamGatherer +# FIXME: Change to ordered sets class SizeParamGatherer( CombineMapper[frozenset[SizeParam], frozenset[SizeParam]]): """ @@ -1437,13 +2059,13 @@ class CachedWalkMapper(WalkMapper[P]): def __init__( self, - _visited_functions: set[VisitKeyT] | None = None + _visited_functions: OrderedSet[VisitKeyT] | None = None ) -> None: super().__init__() - self._visited_arrays_or_names: set[VisitKeyT] = set() + self._visited_arrays_or_names: OrderedSet[VisitKeyT] = OrderedSet() - self._visited_functions: set[VisitKeyT] = \ - _visited_functions if _visited_functions is not None else set() + self._visited_functions: OrderedSet[VisitKeyT] = \ + _visited_functions if _visited_functions is not None else OrderedSet() def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -1494,7 +2116,7 @@ class TopoSortMapper(CachedWalkMapper[[]]): def __init__( self, - _visited_functions: set[VisitKeyT] | None = None) -> None: + _visited_functions: OrderedSet[VisitKeyT] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) self.topological_order: list[Array] = [] @@ -1521,9 +2143,10 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, + # FIXME: Should map_fn be applied to functions too? map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn @@ -1533,18 +2156,17 @@ def clone_for_callee( return type(self)( self.map_fn, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(inputs) except KeyError: - return self._cache.add( - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - expr, Mapper.rec(self, self.map_fn(expr)), key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} @@ -1561,6 +2183,65 @@ class MPMSMaterializerAccumulator: expr: Array +class MPMSMaterializerCache( + CachedMapperCache[ArrayOrNames, MPMSMaterializerAccumulator, []]): + """ + Cache for :class:`MPMSMaterializer`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_created_duplicate: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. 
+ """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_created_duplicate = err_on_created_duplicate + + self._result_key_to_result: dict[ + ArrayOrNames, MPMSMaterializerAccumulator] = {} + + def add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + try: + result = self._result_key_to_result[result.expr] + except KeyError: + if ( + self.err_on_created_duplicate + and _is_mapper_created_duplicate(inputs.expr, result.expr)): + raise MapperCreatedDuplicateError from None + + self._result_key_to_result[result.expr] = result + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result + + def _materialize_if_mpms(expr: Array, nsuccessors: int, predecessors: Iterable[MPMSMaterializerAccumulator] @@ -1572,19 +2253,22 @@ def _materialize_if_mpms(expr: Array, """ from functools import reduce - materialized_predecessors: frozenset[Array] = reduce( - frozenset.union, + materialized_predecessors: FrozenOrderedSet[Array] = reduce( + FrozenOrderedSet.union, (pred.materialized_predecessors for pred in predecessors), - frozenset()) + FrozenOrderedSet()) if nsuccessors > 1 and len(materialized_predecessors) > 1: - new_expr = expr.tagged(ImplStored()) - return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr) + if not expr.tags_of_type(ImplStored): + new_expr = expr.tagged(ImplStored()) + else: + new_expr = expr + return MPMSMaterializerAccumulator(FrozenOrderedSet([new_expr]), new_expr) else: return MPMSMaterializerAccumulator(materialized_predecessors, expr) -class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): +class MPMSMaterializer(CachedMapper[MPMSMaterializerAccumulator, Never, []]): """ See :func:`materialize_with_mpms` for an explanation. @@ -1593,21 +2277,45 @@ class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): A mapping from a node in the expression graph (i.e. an :class:`~pytato.Array`) to its number of successors. 
""" - def __init__(self, nsuccessors: Mapping[Array, int]): - super().__init__() + def __init__( + self, + nsuccessors: Mapping[Array, int], + _cache: MPMSMaterializerCache | None = None): + err_on_collision = False + err_on_created_duplicate = False + + if _cache is None: + _cache = MPMSMaterializerCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + # Does not support functions, so function_cache is ignored + super().__init__(err_on_collision=err_on_collision, _cache=_cache) + self.nsuccessors = nsuccessors - self.cache: dict[ArrayOrNames, MPMSMaterializerAccumulator] = {} - def rec(self, expr: ArrayOrNames) -> MPMSMaterializerAccumulator: - if expr in self.cache: - return self.cache[expr] - result: MPMSMaterializerAccumulator = super().rec(expr) - self.cache[expr] = result - return result + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee( + self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + raise AssertionError("Control shouldn't reach this point.") def _map_input_base(self, expr: InputArgumentBase ) -> MPMSMaterializerAccumulator: - return MPMSMaterializerAccumulator(frozenset([expr]), expr) + return MPMSMaterializerAccumulator(FrozenOrderedSet([expr]), expr) map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -1620,24 +2328,42 @@ def map_named_array(self, expr: NamedArray) -> MPMSMaterializerAccumulator: def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: children_rec = {bnd_name: self.rec(bnd) for bnd_name, bnd in sorted(expr.bindings.items())} + new_children: Mapping[str, Array] = immutabledict({ + bnd_name: bnd.expr + for bnd_name, bnd in sorted(children_rec.items())}) + + if ( + ( + FrozenOrderedSet(new_children.keys()) + == FrozenOrderedSet(expr.bindings.keys())) + and all( + new_children[name] is expr.bindings[name] + for name in expr.bindings)): + new_expr = expr + else: + new_expr = IndexLambda( + expr=expr.expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=new_children, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - new_expr = IndexLambda(expr=expr.expr, - shape=expr.shape, - dtype=expr.dtype, - bindings=immutabledict({bnd_name: bnd.expr - for bnd_name, bnd in sorted(children_rec.items())}), - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], children_rec.values()) def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] - new_expr = Stack(tuple(ary.expr for ary in rec_arrays), - expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(ary.expr for ary in rec_arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + new_expr = expr + else: + new_expr = Stack(new_arrays, expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return 
_materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1645,29 +2371,44 @@ def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] - new_expr = Concatenate(tuple(ary.expr for ary in rec_arrays), - expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(ary.expr for ary in rec_arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + new_expr = expr + else: + new_expr = Concatenate(new_arrays, + expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], rec_arrays) def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) def map_axis_permutation(self, expr: AxisPermutation ) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -1677,16 +2418,23 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: rec_indices = {i: self.rec(idx) for i, idx in enumerate(expr.indices) if isinstance(idx, Array)} - - new_expr = type(expr)(rec_array.expr, - tuple(rec_indices[i].expr - if i in rec_indices - else expr.indices[i] - for i in range( - len(expr.indices))), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_indices = tuple(rec_indices[i].expr + if i in rec_indices + else expr.indices[i] + for i in range( + len(expr.indices))) + if ( + rec_array.expr is expr.array + and all( + new_idx is idx + for idx, new_idx in zip(expr.indices, new_indices, strict=True))): + new_expr = expr + else: + new_expr = type(expr)(rec_array.expr, + new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1699,26 +2447,35 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Reshape(rec_array.expr, expr.newshape, - expr.order, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = Reshape(rec_array.expr, expr.newshape, + expr.order, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: - rec_arrays = [self.rec(ary) for ary in expr.args] - new_expr = 
Einsum(expr.access_descriptors, - tuple(ary.expr for ary in rec_arrays), - expr.redn_axis_to_redn_descr, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + rec_args = [self.rec(ary) for ary in expr.args] + new_args = tuple(ary.expr for ary in rec_args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + new_expr = expr + else: + new_expr = Einsum(expr.access_descriptors, + new_args, + expr.redn_axis_to_redn_descr, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], - rec_arrays) + rec_args) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays ) -> MPMSMaterializerAccumulator: @@ -1726,26 +2483,32 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays def map_loopy_call_result(self, expr: NamedArray) -> MPMSMaterializerAccumulator: # loopy call result is always materialized - return MPMSMaterializerAccumulator(frozenset([expr]), expr) + return MPMSMaterializerAccumulator(FrozenOrderedSet([expr]), expr) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder ) -> MPMSMaterializerAccumulator: - rec_passthrough = self.rec(expr.passthrough_data) rec_send_data = self.rec(expr.send.data) - new_expr = DistributedSendRefHolder( - send=DistributedSend(rec_send_data.expr, - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag, - tags=expr.send.tags), - passthrough_data=rec_passthrough.expr, - ) + if rec_send_data.expr is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + rec_send_data.expr, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag, + tags=expr.send.tags) + rec_passthrough = self.rec(expr.passthrough_data) + if new_send is expr.send and rec_passthrough.expr is expr.passthrough_data: + new_expr = expr + else: + new_expr = DistributedSendRefHolder(new_send, rec_passthrough.expr) + return MPMSMaterializerAccumulator( rec_passthrough.materialized_predecessors, new_expr) def map_distributed_recv(self, expr: DistributedRecv ) -> MPMSMaterializerAccumulator: - return MPMSMaterializerAccumulator(frozenset([expr]), expr) + return MPMSMaterializerAccumulator(FrozenOrderedSet([expr]), expr) def map_named_call_result(self, expr: NamedCallResult ) -> MPMSMaterializerAccumulator: @@ -1775,6 +2538,7 @@ def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, return DictOfNamedArrays(data, tags=source_dict.tags) +# FIXME: Use ordered sets def get_dependencies(expr: DictOfNamedArrays) -> dict[str, frozenset[Array]]: """Returns the dependencies of each named array in *expr*. 
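# --- Hedged usage sketch (not part of this patch) ----------------------------
# The small DAG below is illustrative; get_dependencies maps each result name
# to the set of arrays it (transitively) depends on.
import numpy as np
import pytato as pt
from pytato.transform import get_dependencies

a = pt.make_placeholder(name="a", shape=(5,), dtype=np.float64)
b = pt.make_placeholder(name="b", shape=(5,), dtype=np.float64)
dag = pt.make_dict_of_named_arrays({"s": a + b, "t": 2 * a})

deps = get_dependencies(dag)
print(sorted(deps))      # ['s', 't']
print(a in deps["t"])    # expected: True -- "t" depends on the placeholder a
# -----------------------------------------------------------------------------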
""" @@ -1869,6 +2633,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: # {{{ UsersCollector +# FIXME: Use ordered sets class UsersCollector(CachedMapper[None, Never, []]): """ Maps a graph to a dictionary representation mapping a node to its users, @@ -2010,6 +2775,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> None: self.rec(expr._container) +# FIXME: Use ordered sets def get_users(expr: ArrayOrNames) -> dict[ArrayOrNames, set[ArrayOrNames]]: """ @@ -2024,6 +2790,7 @@ def get_users(expr: ArrayOrNames) -> dict[ArrayOrNames, # {{{ operations on graphs in dict form +# FIXME: Use ordered sets def _recursively_get_all_users( direct_users: Mapping[ArrayOrNames, set[ArrayOrNames]], node: ArrayOrNames) -> frozenset[ArrayOrNames]: @@ -2048,6 +2815,7 @@ def _recursively_get_all_users( return frozenset(result) +# FIXME: Use ordered sets def rec_get_user_nodes(expr: ArrayOrNames, node: ArrayOrNames, ) -> frozenset[ArrayOrNames]: @@ -2069,8 +2837,8 @@ class DataWrapperDeduplicator(CopyMapper): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} @@ -2124,6 +2892,11 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: self.data_wrapper_cache[cache_key] = expr return expr + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: """For the expression graph given as *array_or_names*, replace all @@ -2159,4 +2932,32 @@ def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: # }}} + +# {{{ unify_materialization_tags + +def unify_materialization_tags(array_or_names: ArrayOrNames) -> ArrayOrNames: + """ + For the expression graph given as *array_or_names*, replace all + non-materialized subexpressions with the corresponding materialized version if + one exists elsewhere in the DAG. + """ + from pytato.analysis import collect_materialized_nodes + materialized_exprs = collect_materialized_nodes(array_or_names) + + non_materialized_expr_to_materialized_expr = { + expr.without_tags(ImplStored()): expr + for expr in materialized_exprs} + + def unify(expr): + if expr.tags_of_type(ImplStored): + return expr + try: + return non_materialized_expr_to_materialized_expr[expr] + except KeyError: + return expr + + return map_and_copy(array_or_names, unify) + +# }}} + # vim: foldmethod=marker diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 34f89cbc1..93f71760f 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -1,11 +1,16 @@ -""" +from __future__ import annotations + + +__doc__ = """ .. currentmodule:: pytato.transform.calls .. autofunction:: inline_calls +.. autofunction:: concatenate_calls .. autofunction:: tag_all_calls_to_be_inlined -""" -from __future__ import annotations +.. autofunction:: zero_unused_call_bindings +.. autoclass:: CallSiteLocation +""" __copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" @@ -29,22 +34,91 @@ THE SOFTWARE. 
""" +import itertools +import logging +import numpy as np +from functools import partialmethod, reduce +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + Generator, + List, + Never, + Sequence, + Tuple, + cast, +) +from typing_extensions import Self + +import attrs +from immutabledict import immutabledict +from orderedsets import FrozenOrderedSet, OrderedSet -from typing import TYPE_CHECKING +import pymbolic.primitives as prim +from pytools import memoize_method, memoize_on_first_arg +import pytato.scalar_expr as scalar_expr +from pytato.analysis import collect_nodes_of_type from pytato.array import ( AbstractResultWithNamedArrays, Array, + AxisPermutation, + BasicIndex, + Concatenate, + DataWrapper, DictOfNamedArrays, + Einsum, + IndexBase, + IndexLambda, + InputArgumentBase, Placeholder, + Reshape, + Roll, + ShapeComponent, + ShapeType, + SizeParam, + Stack, + concatenate, + zeros, ) -from pytato.function import Call, NamedCallResult -from pytato.tags import InlineCallTag -from pytato.transform import ArrayOrNames, CopyMapper, _verify_is_array +from pytato.function import Call, FunctionDefinition, NamedCallResult if TYPE_CHECKING: from collections.abc import Mapping +from pytato.tags import ( + ConcatenatedCallInputConcatAxisTag, + ConcatenatedCallOutputSliceAxisTag, + FunctionIdentifier, + ImplStored, + InlineCallTag, + UseInputAxis, +) +from pytato.transform import ( + ArrayOrNames, + CachedMapper, + CachedWalkMapper, + CombineMapper, + CopyMapper, + Deduplicator, + InputGatherer, + TransformMapperCache, + TransformMapperWithExtraArgs, + _verify_is_array, +) +from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.utils import are_shape_components_equal + + +if TYPE_CHECKING: + from pytato.loopy import LoopyCallResult + +logger = logging.getLogger(__name__) + +ArrayOnStackT = Tuple[Tuple[Call, ...], Array] # {{{ inlining @@ -55,6 +129,12 @@ class PlaceholderSubstitutor(CopyMapper): A mapping from the placeholder name to the array that it is to be substituted with. + + .. note:: + + This mapper does not deduplicate subexpressions that occur in both the mapped + expression and the substitutions. Must follow up with a + :class:`pytato.transform.Deduplicator` if duplicates need to be removed. """ def __init__(self, substitutions: Mapping[str, Array]) -> None: @@ -63,32 +143,51 @@ def __init__(self, substitutions: Mapping[str, Array]) -> None: self.substitutions = substitutions def map_placeholder(self, expr: Placeholder) -> Array: + # Can't call rec() to remove duplicates here, because the substituted-in + # expression may potentially contain unrelated placeholders whose names + # collide with the ones being replaced return self.substitutions[expr.name] - def map_named_call_result(self, expr: NamedCallResult) -> NamedCallResult: - raise NotImplementedError( - "PlaceholderSubstitutor does not support functions.") + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + # Only operates within the current stack frame + return expr class Inliner(CopyMapper): """ Primary mapper for :func:`inline_calls`. """ - def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - # inline call sites within the callee. 
- new_expr = super().map_call(expr) - assert isinstance(new_expr, Call) + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + # Must disable collision/duplication checking because we're combining + # expressions that were previously in two different call stack frames + # (and were thus cached separately) + super().__init__( + err_on_collision=False, + err_on_created_duplicate=False, + _cache=_cache, + _function_cache=_function_cache) + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: if expr.tags_of_type(InlineCallTag): - substitutor = PlaceholderSubstitutor(new_expr.bindings) + substitutor = PlaceholderSubstitutor(expr.bindings) return DictOfNamedArrays( - {name: _verify_is_array(substitutor.rec(ret)) - for name, ret in new_expr.function.returns.items()}, - tags=new_expr.tags + {name: _verify_is_array(self.rec(substitutor(ret))) + for name, ret in expr.function.returns.items()}, + tags=expr.tags ) else: - return new_expr + return super().map_call(expr) def map_named_call_result(self, expr: NamedCallResult) -> Array: new_call_or_inlined_expr = self.rec(expr._container) @@ -104,7 +203,11 @@ class InlineMarker(CopyMapper): Primary mapper for :func:`tag_all_calls_to_be_inlined`. """ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return super().map_call(expr).tagged(InlineCallTag()) + rec_expr = super().map_call(expr) + if rec_expr.tags_of_type(InlineCallTag): + return rec_expr + else: + return rec_expr.tagged(InlineCallTag()) def inline_calls(expr: ArrayOrNames) -> ArrayOrNames: @@ -132,4 +235,2009 @@ def tag_all_calls_to_be_inlined(expr: ArrayOrNames) -> ArrayOrNames: # }}} + +# {{{ _collect_used_call_inputs + +class _UsedCallInputCollector(CachedWalkMapper[[]]): + def __init__( + self, + _fn_input_gatherers: + dict[FunctionDefinition, InputGatherer] | None = None, + _visited_functions: OrderedSet[Any] | None = None + ) -> None: + if _fn_input_gatherers is None: + _fn_input_gatherers = {} + + self.call_to_used_inputs: dict[Call, OrderedSet[Placeholder]] = {} + self._fn_input_gatherers = _fn_input_gatherers + + super().__init__(_visited_functions=_visited_functions) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _fn_input_gatherers=self._fn_input_gatherers, + _visited_functions=self._visited_functions) + + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: + return expr + + def get_function_definition_cache_key( + self, expr: FunctionDefinition) -> FunctionDefinition: + return expr + + # type-ignore-reason: CachedWalkMapper's method takes in variadic args, kwargs + def map_named_call_result( + self, expr: NamedCallResult, # type: ignore[override] + ) -> None: + call = expr._container + try: + input_gatherer = self._fn_input_gatherers[call.function] + except KeyError: + input_gatherer = InputGatherer() + self._fn_input_gatherers[call.function] = input_gatherer + + used_inputs = self.call_to_used_inputs.setdefault(call, OrderedSet()) + used_inputs |= input_gatherer(call.function.returns[expr.name]) + + super().map_named_call_result(expr) + + +def _collect_used_call_inputs( + expr: ArrayOrNames) -> immutabledict[Call, FrozenOrderedSet[Placeholder]]: + """ + Returns a mapping from 
:class:`~pytato.function.Call` to the set of input + :class:`~pt.array.Placeholder`\ s belonging to its function definition that are + actually used by the expression. In other words, it returns the inputs + corresponding to the call bindings that would remain in the DAG if the call was + inlined. + """ + collector = _UsedCallInputCollector() + collector(expr) + + return immutabledict({ + call: FrozenOrderedSet(inputs) + for call, inputs in collector.call_to_used_inputs.items()}) + +# }}} + + +# {{{ zero_unused_call_bindings + +class _UnusedCallBindingZeroer(CopyMapper): + """ + Mapper to replace unused bindings of :class:`~pytato.function.Call` with zeros + of appropriate shape. + """ + def __init__( + self, + call_to_used_inputs: Mapping[Call, FrozenOrderedSet[Placeholder]], + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.call_to_used_inputs = call_to_used_inputs + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + call_to_used_inputs=self.call_to_used_inputs, + _function_cache=self._function_cache) + + def map_call(self, expr: Call) -> Call: + new_function = self.rec_function_definition(expr.function) + new_bindings = {} + for name, bnd in expr.bindings.items(): + if isinstance(bnd, Array): + if ( + expr.function.get_placeholder(name) + in self.call_to_used_inputs[expr]): + new_bnd = self.rec(bnd) + else: + new_bnd = zeros(bnd.shape, bnd.dtype) + else: + new_bnd = bnd + new_bindings[name] = new_bnd + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values()))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) + + +def zero_unused_call_bindings(expr: ArrayOrNames) -> ArrayOrNames: + """ + Replaces :class:`~pytato.function.Call` bindings that are not used by the + expression with arrays of zeros of the appropriate shape. This can be necessary + for certain transformations such as concatenation, where otherwise bindings + may be retained in the DAG when they should be dropped. + """ + call_to_used_inputs = _collect_used_call_inputs(expr) + return _UnusedCallBindingZeroer(call_to_used_inputs)(expr) + +# }}} + + +# {{{ Concatenatability + +@attrs.define(frozen=True) +class Concatenatability: + """ + Describes how a particular array expression can be concatenated. + """ + + +@attrs.define(frozen=True) +class ConcatableAlongAxis(Concatenatability): + """ + Used to describe an array expression that is concatenatable along *axis*. + """ + axis: int + + +@attrs.define(frozen=True) +class ConcatableIfConstant(Concatenatability): + """ + Used to describe an array expression in a function body that can be + concatenated only if the expression is the same across call-sites. + """ + +# }}} + + +# {{{ concatenate_calls + +@attrs.define(frozen=True) +class CallSiteLocation: + r""" + Records a call-site's location in a :mod:`pytato` expression. + + .. attribute:: call + + The instance of :class:`~pytato.function.Call` being called at this + location. + + .. attribute:: stack + + The call sites within which this particular call is called. + For eg. if ``stack = (c1, c2)``, then :attr:`call` is called within + ``c2``\ 's function body which itself is called from ``c1``\ 's + function body. + """ + call: Call + stack: Tuple[Call, ...] 
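# --- Illustrative sketch (not part of this patch) ----------------------------
# A self-contained toy model of what `stack` records; the classes below are
# made up and only mirror the shape of CallSiteLocation. If the caller invokes
# f, and f's body in turn invokes g, then g's call site lives at
# stack == (call_to_f,).
from dataclasses import dataclass


@dataclass(frozen=True)
class _ToyCall:
    name: str


@dataclass(frozen=True)
class _ToyCallSiteLocation:
    call: _ToyCall
    stack: tuple


call_to_f = _ToyCall("f")             # called at the outermost level
call_to_g = _ToyCall("g")             # called from inside f's body

outer = _ToyCallSiteLocation(call=call_to_f, stack=())
inner = _ToyCallSiteLocation(call=call_to_g, stack=(call_to_f,))
# -----------------------------------------------------------------------------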
+ + +class CallSiteDependencyCollector( + CombineMapper[FrozenOrderedSet[CallSiteLocation], Never]): + r""" + Collects all the call sites in a :mod:`pytato` expression along with their + interdependencies. + + .. attribute:: stack + + The stack of calls at which the calls are being collected. This + attribute is used to specify :attr:`CallSiteLocation.stack` in the + :class:`CallSiteLocation`\ s being built. Must be altered (by creating + a new instance of the mapper) before entering the function body of a + new :class:`~pytato.function.Call`. + + .. attribute:: call_site_to_dep_call_sites + + A mapping from call site to the call sites on which it depends, for each + call site present in the expression. + """ + def __init__(self, stack: Tuple[Call, ...]) -> None: + self.stack = stack + self.call_site_to_dep_call_sites: \ + Dict[CallSiteLocation, CallSiteLocation] = {} + super().__init__() + + def combine(self, *args: FrozenOrderedSet[CallSiteLocation] + ) -> FrozenOrderedSet[CallSiteLocation]: + return reduce(lambda a, b: a | b, args, FrozenOrderedSet()) + + def map_size_param(self, expr: SizeParam) -> FrozenOrderedSet[CallSiteLocation]: + return FrozenOrderedSet() + + def map_call(self, expr: Call) -> FrozenOrderedSet[CallSiteLocation]: + cs = CallSiteLocation(expr, self.stack) + + new_mapper_for_fn = CallSiteDependencyCollector(stack=self.stack + (expr,)) + dependent_call_sites = self.combine( + *[ + self.rec(bnd) for bnd in expr.bindings.values() + if isinstance(bnd, Array)], + *[new_mapper_for_fn(ret) + for ret in expr.function.returns.values()]) + + self.call_site_to_dep_call_sites[cs] = dependent_call_sites + self.call_site_to_dep_call_sites.update( + new_mapper_for_fn.call_site_to_dep_call_sites) + + return self.combine(FrozenOrderedSet([cs]), dependent_call_sites) + + +class _NamedCallResultReplacerPostConcatenate(CopyMapper): + """ + Mapper to replace instances of :class:`~pytato.function.NamedCallResult` as + per :attr:`replacement_map`. + + .. attribute:: current_stack + + Records the stack to track which function body the mapper is + traversing. Must be altered (by creating a new instance) before + entering the function body of a new :class:`~pytato.function.Call`. + """ + def __init__( + self, + replacement_map: Mapping[ + Tuple[ + NamedCallResult, + Tuple[Call, ...]], + Array], + current_stack: Tuple[Call, ...], + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.replacement_map = replacement_map + self.current_stack = current_stack + + def clone_for_callee( + self, function: FunctionDefinition) -> Self: + raise AssertionError("Control should not reach here." + " Call clone_with_new_call_on_stack instead.") + + def clone_with_new_call_on_stack(self, expr: Call) -> Self: + # type-ignore-reason: Mapper class does not define these attributes. 
+ return type(self)( # type: ignore[call-arg] + self.replacement_map, # type: ignore[attr-defined] + self.current_stack + (expr,), # type: ignore[attr-defined] + _function_cache=self._function_cache + ) + + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + # No clone here because we're cloning in map_call instead + new_returns = {name: self.rec(ret) + for name, ret in expr.returns.items()} + if all( + new_ret is ret + for ret, new_ret in zip( + expr.returns.values(), + new_returns.values())): + return expr + else: + return attrs.evolve(expr, returns=immutabledict(new_returns)) + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + new_mapper = self.clone_with_new_call_on_stack(expr) + new_function = new_mapper.rec_function_definition(expr.function) + new_bindings = { + name: self.rec(bnd) if isinstance(bnd, Array) else bnd + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values()))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + try: + new_expr = self.replacement_map[expr, self.current_stack] + if isinstance(new_expr, NamedCallResult): + return super().map_named_call_result(new_expr) + else: + return self.rec(new_expr) + except KeyError: + return super().map_named_call_result(expr) + + +def _have_same_axis_length(arrays: Collection[Array], + iaxis: int) -> bool: + """ + Returns *True* only if every array in *arrays* have the same axis length + along *iaxis*. + """ + axis_length = next(iter(arrays)).shape[iaxis] + return all(are_shape_components_equal(other_ary.shape[iaxis], + axis_length) + for other_ary in arrays) + + +def _have_same_axis_length_except(arrays: Collection[Array], + iaxis: int) -> bool: + """ + Returns *True* only if every array in *arrays* have the same + dimensionality and have axes with the same lengths except along the + *iaxis*-axis. + """ + ndim = next(iter(arrays)).ndim + return (all(ary.ndim == ndim for ary in arrays) + and all(_have_same_axis_length(arrays, idim) + for idim in range(ndim) + if idim != iaxis)) + + +@attrs.define(frozen=True) +class _InputConcatabilityGetterAcc: + r""" + Return type for :class:`_InputConcatabilityGetter`. An instance of this class is + returned after mapping a :class:`~pytato.Array` expression. + + .. attribute:: seen_inputs + + A :class:`FrozenOrderedSet` of all :class:`pytato.InputArgumentBase` + predecessors of a node. + + .. attribute:: input_concatability + + Records the constraints that come along with concatenating the array + being mapped. The constraints are recorded as a mapping from the axes + of the array being mapped to the axes of the input arguments. This + mapping informs us which axes in the :class:`InputArgumentBase`\ s' + must be concatenated to soundly concatenate a particular axis in the + array being mapped. The axes in this mapping are represented using + :class:`int`. If certain axes are missing in this mapping, then + concatenation cannot be performed along those axes for the mapped + array. 
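# --- Hedged illustration (not part of this patch) ----------------------------
# Shape of `input_concatability` for a hypothetical matmul-like expression
# out = a @ b with placeholder inputs a and b. The string keys "a"/"b" stand in
# for the actual InputArgumentBase instances; the Concatenatability classes are
# the ones defined earlier in this file (importable from pytato.transform.calls
# once this patch is applied).
from immutabledict import immutabledict

from pytato.transform.calls import ConcatableAlongAxis, ConcatableIfConstant

illustrative_input_concatability = immutabledict({
    # concatenating `out` along axis 0 forces `a` along its axis 0 and
    # requires `b` to be identical across call sites
    ConcatableAlongAxis(0): immutabledict(
        {"a": ConcatableAlongAxis(0), "b": ConcatableIfConstant()}),
    # concatenating `out` along axis 1 forces `b` along its axis 1 instead
    ConcatableAlongAxis(1): immutabledict(
        {"a": ConcatableIfConstant(), "b": ConcatableAlongAxis(1)}),
    # `out` can also be treated as constant if it is identical across call sites
    ConcatableIfConstant(): immutabledict(
        {"a": ConcatableIfConstant(), "b": ConcatableIfConstant()}),
})
# -----------------------------------------------------------------------------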
+ """ + seen_inputs: FrozenOrderedSet[InputArgumentBase] + input_concatability: Mapping[Concatenatability, + Mapping[InputArgumentBase, Concatenatability]] + + def __post_init__(self) -> None: + assert all( + FrozenOrderedSet(input_concat.keys()) == self.seen_inputs + for input_concat in self.input_concatability.values()) + + __attrs_post_init__ = __post_init__ + + +class NonConcatableExpression(RuntimeError): + """ + Used internally by :class:`_ScalarExprConcatabilityMapper`. + """ + + +class _InvalidConcatenatability(RuntimeError): + """ + Used internally by :func:`_get_ary_to_concatenatabilities`. + """ + + +class _ScalarExprConcatabilityMapper(scalar_expr.CombineMapper): + """ + Maps :attr:`~pytato.array.IndexLambda.expr` to the axes of the bindings + that must be concatenated to concatenate the IndexLambda's + :attr:`iaxis`-axis. + + .. attribute:: allow_indirect_addr + + If *True* indirect access are allowed. However, concatenating along the + iaxis-axis would be sound only if the binding which is being indexed + into is same for all the expressions to be concatenated. + """ + def __init__(self, iaxis: int, allow_indirect_addr: bool) -> None: + self.iaxis = iaxis + self.allow_indirect_addr = allow_indirect_addr + super().__init__() + + def combine(self, values: Collection[Mapping[str, Concatenatability]] + ) -> Mapping[str, Concatenatability]: + result: Dict[str, Concatenatability] = {} + for value in values: + for bnd_name, iaxis in value.items(): + try: + if result[bnd_name] != iaxis: + # only one axis of a particular binding can be + # concatenated. If multiple axes must be concatenated + # that means the index lambda is not + # iaxis-concatenatable. + raise NonConcatableExpression + except KeyError: + result[bnd_name] = iaxis + + return immutabledict(result) + + def map_variable(self, expr: prim.Variable) -> Mapping[str, Concatenatability]: + if expr.name == f"_{self.iaxis}": + raise NonConcatableExpression + else: + return immutabledict() + + def map_constant(self, expr: Any) -> Mapping[str, Concatenatability]: + return immutabledict() + + map_nan = map_constant + + def map_subscript(self, expr: prim.Subscript + ) -> Mapping[str, Concatenatability]: + name: str = expr.aggregate.name + rec_indices: List[Mapping[str, Concatenatability]] = [] + for iaxis, idx in enumerate(expr.index_tuple): + if idx == prim.Variable(f"_{self.iaxis}"): + rec_indices.append({name: ConcatableAlongAxis(iaxis)}) + else: + rec_idx = self.rec(idx) + if rec_idx: + if not self.allow_indirect_addr: + raise NonConcatableExpression + else: + # indirect accesses cannot be concatenated in the general + # case unless the indexee is the same for the + # expression graphs being concatenated. + pass + rec_indices.append(rec_idx) + + combined_rec_indices = dict(self.combine(rec_indices)) + + if name not in combined_rec_indices: + combined_rec_indices[name] = ConcatableIfConstant() + + return immutabledict(combined_rec_indices) + + +@memoize_on_first_arg +def _get_binding_to_concatenatability_scalar_expr( + expr: scalar_expr.ScalarExpression, + iaxis: int, + allow_indirect_addr: bool) -> Mapping[str, Concatenatability]: + mapper = _ScalarExprConcatabilityMapper(iaxis, allow_indirect_addr) + return mapper(expr) # type: ignore[no-any-return] + + + +def _get_binding_to_concatenatability(expr: scalar_expr.ScalarExpression, + iaxis: int, + allow_indirect_addr: bool, + ) -> Mapping[str, Concatenatability]: + """ + Maps *expr* using :class:`_ScalarExprConcatabilityMapper`. 
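# --- Hedged usage sketch (not part of this patch) ----------------------------
# For an IndexLambda whose scalar expression is a[_0, _1] + b[_1, _0],
# concatenating the result along axis 0 requires binding "a" to be concatenated
# along its axis 0 and "b" along its axis 1. The helper is module-private, so
# this assumes the patch is applied; the expression is illustrative.
import pymbolic.primitives as prim

from pytato.transform.calls import _get_binding_to_concatenatability

_0, _1 = prim.Variable("_0"), prim.Variable("_1")
a, b = prim.Variable("a"), prim.Variable("b")
scalar_expr_example = a[_0, _1] + b[_1, _0]

print(_get_binding_to_concatenatability(
    scalar_expr_example, iaxis=0, allow_indirect_addr=False))
# expected (roughly): {'a': ConcatableAlongAxis(axis=0),
#                      'b': ConcatableAlongAxis(axis=1)}
# -----------------------------------------------------------------------------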
+ """ + if np.isscalar(expr): + # In some cases expr may just be a number, which can't be memoized on + return {} + + return _get_binding_to_concatenatability_scalar_expr( + expr, iaxis, allow_indirect_addr) + + +def _combine_input_accs( + operand_accs: Tuple[_InputConcatabilityGetterAcc, ...], + concat_to_operand_concatabilities: Mapping[Concatenatability, + Tuple[Concatenatability, ...] + ], +) -> _InputConcatabilityGetterAcc: + """ + For an index lambda ``I`` with operands ``I1, I2, .. IN`` that specify their + concatenatability constraints using *operand_accs*, this routine returns + the axes concatenation constaints of ``I``. + + :arg concat_to_operand_concatabilities: Mapping of the form ``concat_I -> + (C_I1, C_I2, ..., C_IN)`` specifying the concatabilities of the + operands ``I1, I2, .., IN`` in order to concatenate the + ``I`` axis via the criterion ``conncat_I``. + """ + + input_concatabilities: Dict[Concatenatability, Mapping[InputArgumentBase, + Concatenatability]] = {} + seen_inputs: FrozenOrderedSet[InputArgumentBase] = reduce( + FrozenOrderedSet.union, + (operand_acc.seen_inputs for operand_acc in operand_accs), + FrozenOrderedSet()) + + # The core logic here is to filter the iaxis in out_axis_to_operand_axes + # so that all the operands agree on how the input arguments must be + # concatenated. + + for out_concat, operand_concatabilities in (concat_to_operand_concatabilities + .items()): + is_i_out_axis_concatenatable = True + input_concatability: Dict[InputArgumentBase, Concatenatability] = {} + + for operand_concatability, operand_acc in zip(operand_concatabilities, + operand_accs, + strict=True): + if operand_concatability not in ( + operand_acc.input_concatability): + # required operand concatability cannot be achieved + # => out_concat cannot be concatenated + is_i_out_axis_concatenatable = False + break + + for input_arg, input_concat in ( + operand_acc + .input_concatability[operand_concatability] + .items()): + try: + if input_concatability[input_arg] != input_concat: + is_i_out_axis_concatenatable = False + break + except KeyError: + input_concatability[input_arg] = input_concat + if not is_i_out_axis_concatenatable: + break + + if is_i_out_axis_concatenatable: + input_concatabilities[out_concat] = immutabledict(input_concatability) + + return _InputConcatabilityGetterAcc(seen_inputs, + immutabledict(input_concatabilities)) + + +@attrs.define(frozen=True) +class FunctionConcatenability: + r""" + Records a valid concatenatability criterion for a + :class:`pytato.function.FunctionDefinition`. + + .. attribute:: output_to_concatenatability + + A mapping from the name of a + :class:`FunctionDefinition`\ 's returned array to how it should be + concatenated. + + .. attribute:: input_to_concatenatability + + A mapping from a :class:`FunctionDefinition`\ 's parameter to how it + should be concatenated. + + .. note:: + + A :class:`FunctionDefinition` typically has multiple valid + concatenability constraints. This class only records one of those valid + constraints. 
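# --- Hedged illustration (not part of this patch) ----------------------------
# One possible FunctionConcatenability for a function with parameters "x" and
# "y" and a single return value "out", where every input and output is
# concatenated along axis 0. The names are illustrative; the classes are the
# ones defined in this module.
from immutabledict import immutabledict

from pytato.transform.calls import ConcatableAlongAxis, FunctionConcatenability

fn_concat = FunctionConcatenability(
    output_to_concatenatability=immutabledict({"out": ConcatableAlongAxis(0)}),
    input_to_concatenatability=immutabledict(
        {"x": ConcatableAlongAxis(0), "y": ConcatableAlongAxis(0)}))
print(fn_concat)   # rendered by the __str__ defined below
# -----------------------------------------------------------------------------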
+ """ + output_to_concatenatability: Mapping[str, Concatenatability] + input_to_concatenatability: Mapping[str, Concatenatability] + + def __str__(self) -> str: + outputs = [] + for name, concat in self.output_to_concatenatability.items(): + outputs.append(f"{name} => {concat}") + + inputs = [] + for name, concat in self.input_to_concatenatability.items(): + inputs.append(f"{name} => {concat}") + + output_str = "\n".join(outputs) + input_str = "\n".join(inputs) + + return (f"Outputs:\n--------\n{output_str}\n" + f"========\nInputs:\n-------\n{input_str}\n" + "========") + + +def _combine_named_result_accs_simple( + named_result_accs: Mapping[str, _InputConcatabilityGetterAcc] +) -> Tuple[FunctionConcatenability, ...]: + """ + Combines the concantenatability constraints of named results of a + :class:`FunctionDefinition` and returns a :class:`tuple` of the valid + *simple* concatenatable constraints (i.e., concatenation of all inputs/outputs + along the same axis). + """ + valid_concatenatabilities: List[FunctionConcatenability] = [] + + input_args = reduce( + FrozenOrderedSet.union, + [ + acc.seen_inputs + for acc in named_result_accs.values()], + FrozenOrderedSet()) + + candidate_concat_axes = reduce( + FrozenOrderedSet.union, + [ + FrozenOrderedSet(acc.input_concatability.keys()) + for acc in named_result_accs.values()], + FrozenOrderedSet()) + + # print(f"{candidate_concat_axes=}") + + for i_concat_axis in candidate_concat_axes: + # if isinstance(i_concat_axis, ConcatableAlongAxis) and i_concat_axis.axis == 0: + # for acc in named_result_accs.values(): + # for ary, concat in acc.input_concatability[i_concat_axis].items(): + # print(f"{type(ary).__name__=}, {ary.name=}, {ary.shape=}, {id(ary)=}, {concat=}") + # print("") + if ( + all( + i_concat_axis in acc.input_concatability + for acc in named_result_accs.values()) + and all( + ( + i_input_axis == i_concat_axis + or isinstance(i_input_axis, ConcatableIfConstant)) + for acc in named_result_accs.values() + for i_input_axis in ( + acc.input_concatability[i_concat_axis].values()))): + output_concats = {name: i_concat_axis for name in named_result_accs} + input_concats = {pl.name: i_concat_axis + for pl in input_args + if isinstance(pl, Placeholder)} + valid_concatenatabilities.append( + FunctionConcatenability(immutabledict(output_concats), + immutabledict(input_concats))) + + return valid_concatenatabilities + + +# FIXME: Find a more efficient way to do this. The number of candidates +# explodes when the function being concatenated has more than a few outputs +def _combine_named_result_accs_exhaustive( + named_result_accs: Mapping[str, _InputConcatabilityGetterAcc] +) -> Generator[ + FunctionConcatenability, + None, + None]: + """ + Combines the concantenatability constraints of named results of a + :class:`FunctionDefinition` and returns a :class:`tuple` of the valid + concatenatable constraints. 
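# --- Illustrative sketch (not part of this patch) ----------------------------
# A self-contained model of the exhaustive search performed below: take the
# Cartesian product of each output's candidate concatenation axes and keep only
# the combinations whose induced constraints on a shared input agree. All
# numbers and names are made up for illustration.
import itertools

# candidate concatenation axes per output
output_candidates = {"out0": [0, 1], "out1": [0]}
# constraint that each (output, axis) choice places on a shared input "x"
induced_input_axis = {("out0", 0): 0, ("out0", 1): 1, ("out1", 0): 0}

valid = []
for combo in itertools.product(
        *([(name, ax) for ax in axes]
          for name, axes in output_candidates.items())):
    required = {induced_input_axis[choice] for choice in combo}
    if len(required) == 1:   # every output agrees on how "x" must be concatenated
        valid.append(dict(combo))

print(valid)   # [{'out0': 0, 'out1': 0}]
# -----------------------------------------------------------------------------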
+ """ + potential_concatenatable_output_axes = itertools.product(*[ + [(name, concat) for concat in acc.input_concatability] + for name, acc in named_result_accs.items()]) + + for output_concats in potential_concatenatable_output_axes: + is_concatenatable = True + input_concatability: Dict[InputArgumentBase, Concatenatability] = {} + + for result_name, iresult_axis in output_concats: + for input_arg, i_input_axis in ( + named_result_accs[result_name] + .input_concatability[iresult_axis] + .items()): + try: + if input_concatability[input_arg] != i_input_axis: + is_concatenatable = False + break + except KeyError: + input_concatability[input_arg] = i_input_axis + + if not is_concatenatable: + break + + if is_concatenatable: + pl_concatabilities = {pl.name: concat + for pl, concat in input_concatability.items() + if isinstance(pl, Placeholder)} + yield FunctionConcatenability(immutabledict(output_concats), + immutabledict(pl_concatabilities)) + + +class _InputConcatabilityGetter( + CachedMapper[ArrayOrNames, Never, [ArrayOrNames, ...]]): + """ + Maps :class:`pytato.array.Array` expressions to + :class:`_InputConcatenatabilityGetterAcc` that summarizes constraints + induced on the concatenatability of the inputs of the expression by the + expression's concatenatability. + """ + def get_cache_key( + self, expr: ArrayOrNames, *exprs_from_other_calls: ArrayOrNames + ) -> tuple[ArrayOrNames, ...]: + return (expr, *exprs_from_other_calls) + + def _map_input_arg_base( + self, + expr: InputArgumentBase, + *exprs_from_other_calls: InputArgumentBase, + ) -> _InputConcatabilityGetterAcc: + input_concatenatability: Dict[Concatenatability, + Mapping[InputArgumentBase, + Concatenatability]] = {} + for idim in range(expr.ndim): + input_concatenatability[ConcatableAlongAxis(idim)] = immutabledict( + {expr: ConcatableAlongAxis(idim)}) + + input_concatenatability[ConcatableIfConstant()] = immutabledict( + {expr: ConcatableIfConstant()}) + + return _InputConcatabilityGetterAcc(FrozenOrderedSet([expr]), + immutabledict(input_concatenatability)) + + map_placeholder = _map_input_arg_base + map_data_wrapper = _map_input_arg_base + + def _map_index_lambda_like( + self, + expr: Array, + *exprs_from_other_calls: Array, + allow_indirect_addr: bool) -> _InputConcatabilityGetterAcc: + expr = to_index_lambda(expr) + exprs_from_other_calls = tuple( + to_index_lambda(ary) for ary in exprs_from_other_calls) + + input_accs = tuple( + self.rec( + expr.bindings[name], + *[ + ary.bindings[name] + for ary in exprs_from_other_calls]) + for name in sorted(expr.bindings.keys())) + expr_concat_to_input_concats: Dict[Concatenatability, + Tuple[Concatenatability, ...]] = {} + + for iaxis in range(expr.ndim): + for ary in (expr,) + exprs_from_other_calls: + # If the array has length 1 along this axis, the index may have been + # dropped from the scalar expression, in which case + # _get_binding_to_concatenatability will fail to determine the + # concatenatability. 
If that happens, we have to look at the other + # expressions in the hope that one of them has a non-1 length + if ary.shape[iaxis] == 1: + continue + try: + bnd_name_to_concat = _get_binding_to_concatenatability( + ary.expr, iaxis, allow_indirect_addr) + expr_concat_to_input_concats[ConcatableAlongAxis(iaxis)] = ( + tuple(concat + for _, concat in sorted(bnd_name_to_concat.items(), + key=lambda x: x[0])) + ) + except NonConcatableExpression: + # print(f"{iaxis=}") + # print(f"{ary.expr=}") + # print(f"{ary.shape=}") + break + + expr_concat_to_input_concats[ConcatableIfConstant()] = tuple( + ConcatableIfConstant() for _ in expr.bindings) + + return _combine_input_accs(input_accs, expr_concat_to_input_concats) + + map_index_lambda = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_einsum = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_basic_index = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_roll = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_stack = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_concatenate = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_axis_permutation = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_reshape = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + + map_contiguous_advanced_index = partialmethod(_map_index_lambda_like, + allow_indirect_addr=True) + map_non_contiguous_advanced_index = partialmethod(_map_index_lambda_like, + allow_indirect_addr=True) + + def map_named_call_result( + self, + expr: NamedCallResult, + *exprs_from_other_calls: NamedCallResult, + ) -> _InputConcatabilityGetterAcc: + raise NotImplementedError("nested functions aren't supported.") + + # FIXME: Update the code below to work after changing + # _InputConcatabilityGetter to look at all function calls instead of just + # the template call + assert isinstance(expr._container, Call) + valid_concatenatabilities = _get_valid_concatenatability_constraints_simple( + expr._container.function) + + expr_concat_possibilities = FrozenOrderedSet( + valid_concatenability.output_to_concatenatability[expr.name] + for valid_concatenability in valid_concatenatabilities + ) + + input_concatenatabilities: Dict[Concatenatability, + Mapping[InputArgumentBase, + Concatenatability]] = {} + rec_bindings = {bnd_name: self.rec(binding) + for bnd_name, binding in expr._container.bindings.items()} + callee_acc = self.rec(expr._container.function.returns[expr.name]) + seen_inputs: OrderedSet[InputArgumentBase] = OrderedSet() + + for seen_input in callee_acc.seen_inputs: + if isinstance(seen_input, Placeholder): + seen_inputs.update(rec_bindings[seen_input.name].seen_inputs) + elif isinstance(seen_input, (DataWrapper, SizeParam)): + seen_inputs.add(seen_input) + else: + raise NotImplementedError(type(seen_input)) + + for concat_possibility in expr_concat_possibilities: + caller_input_concatabilities: Dict[InputArgumentBase, + Concatenatability] = {} + + is_concat_possibility_valid = True + for callee_input_arg, callee_input_concat in ( + callee_acc.input_concatability[concat_possibility].items()): + caller_acc = rec_bindings[callee_input_arg.name] + if isinstance(callee_input_arg, Placeholder): + if callee_input_concat in caller_acc.input_concatability: + for caller_input_arg, caller_input_concat in ( + caller_acc + .input_concatability[callee_input_concat] + .items()): + try: + if 
(caller_input_concatabilities[caller_input_arg] + != caller_input_concat): + is_concat_possibility_valid = False + break + except KeyError: + caller_input_concatabilities[callee_input_arg] = ( + caller_input_concat) + if not is_concat_possibility_valid: + break + else: + is_concat_possibility_valid = False + break + elif isinstance(callee_input_arg, (DataWrapper, SizeParam)): + try: + if (caller_input_concatabilities[callee_input_arg] + != callee_input_concat): + is_concat_possibility_valid = False + break + except KeyError: + caller_input_concatabilities[callee_input_arg] = ( + callee_input_concat) + else: + raise NotImplementedError(type(callee_input_arg)) + + if is_concat_possibility_valid: + input_concatenatabilities[concat_possibility] = immutabledict( + caller_input_concatabilities) + + return _InputConcatabilityGetterAcc(FrozenOrderedSet(seen_inputs), + immutabledict(input_concatenatabilities)) + + def map_loopy_call_result( + self, + expr: "LoopyCallResult", + *exprs_from_other_calls: "LoopyCallResult", + ) -> _InputConcatabilityGetterAcc: + raise ValueError("Loopy Calls are illegal to concatenate. Maybe" + " rewrite the operation as array operations?") + + +def _verify_arrays_can_be_concated_along_axis( + arrays: Collection[Array], + fields_that_must_be_same: Collection[str], + iaxis: int) -> None: + """ + Performs some common checks if *arrays* from different function bodies can be + concatenated. + + .. attribute:: arrays + + Corresponding expressions from the function bodies for call-site that + are being checked for concatenation along *iaxis*. + """ + if not _have_same_axis_length_except(arrays, iaxis): + raise _InvalidConcatenatability("Axis lengths are incompatible.") + for field in fields_that_must_be_same: + if len({getattr(ary, field) for ary in arrays}) != 1: + raise _InvalidConcatenatability(f"Field '{field}' varies across calls.") + + +def _verify_arrays_same(arrays: Collection[Array]) -> None: + if len(set(arrays)) != 1: + raise _InvalidConcatenatability("Arrays are not the same.") + + +def _get_concatenated_shape(arrays: Collection[Array], iaxis: int) -> ShapeType: + # type-ignore-reason: mypy expects 'ary.shape[iaxis]' as 'int' since the + # 'start' is an 'int' + concatenated_axis_length = sum(ary.shape[iaxis] # type: ignore[misc] + for ary in arrays) + template_ary = next(iter(arrays)) + + return tuple(dim + if idim != iaxis + else concatenated_axis_length + for idim, dim in enumerate(template_ary.shape) + ) + + +class _ConcatabilityCollector(CachedWalkMapper): + def __init__( + self, + current_stack: Tuple[Call, ...], + _visited_functions: OrderedSet[Any] | None = None + ) -> None: + self.ary_to_concatenatability: Dict[ArrayOnStackT, Concatenatability] = {} + self.current_stack = current_stack + self.call_sites_on_hold: OrderedSet[Call] = OrderedSet() + super().__init__(_visited_functions=_visited_functions) + + # type-ignore-reason: CachedWalkMaper takes variadic `*args, **kwargs`. + def get_cache_key(self, # type: ignore[override] + expr: ArrayOrNames, + *args: Any, + ) -> Tuple[ArrayOrNames, Any]: + return (expr, args) + + # type-ignore-reason: CachedWalkMaper takes variadic `*args, **kwargs`. 
+ def get_function_definition_cache_key( + self, # type: ignore[override] + expr: FunctionDefinition, + *args: Any, + ) -> tuple[ArrayOrNames, Any]: + return (expr, args) + + def _record_concatability(self, expr: Array, + concatenatability: Concatenatability, + ) -> None: + key = (self.current_stack, expr) + assert key not in self.ary_to_concatenatability + self.ary_to_concatenatability[key] = concatenatability + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control should not reach here." + " Call clone_with_new_call_on_stack instead.") + + def clone_with_new_call_on_stack(self, expr: Call) -> Self: + # type-ignore-reason: Mapper class does not define these attributes. + return type(self)( # type: ignore[call-arg] + self.current_stack + (expr,), # type: ignore[attr-defined] + _visited_functions=self._visited_functions + ) + + def _map_input_arg_base(self, + expr: InputArgumentBase, + concatenatability: Concatenatability, + exprs_from_other_calls: Tuple[Array, ...], + ) -> None: + if isinstance(concatenatability, ConcatableIfConstant): + _verify_arrays_same((expr,) + exprs_from_other_calls) + elif isinstance(concatenatability, ConcatableAlongAxis): + # FIXME: Probably needs some extra handling for broadcastable arrays + _verify_arrays_can_be_concated_along_axis( + (expr,) + exprs_from_other_calls, + ["dtype", "name"], + concatenatability.axis) + else: + raise NotImplementedError(type(concatenatability)) + + self._record_concatability(expr, concatenatability) + + map_placeholder = _map_input_arg_base # type: ignore[assignment] + map_data_wrapper = _map_input_arg_base # type: ignore[assignment] + + def _map_index_lambda_like(self, + expr: Array, + concatenatability: Concatenatability, + exprs_from_other_calls: Tuple[Array, ...], + allow_indirect_addr: bool, + ) -> None: + self._record_concatability(expr, concatenatability) + + idx_lambda = to_index_lambda(expr) + idx_lambdas_from_other_calls = tuple(to_index_lambda(ary) + for ary in exprs_from_other_calls) + + if isinstance(concatenatability, ConcatableIfConstant): + _verify_arrays_same((idx_lambda,) + idx_lambdas_from_other_calls) + for bnd_name in idx_lambda.bindings: + self.rec( + idx_lambda.bindings[bnd_name], concatenatability, + tuple( + ary.bindings[bnd_name] + for ary in idx_lambdas_from_other_calls)) + elif isinstance(concatenatability, ConcatableAlongAxis): + _verify_arrays_can_be_concated_along_axis( + (idx_lambda, ) + idx_lambdas_from_other_calls, + ["dtype"], + concatenatability.axis) + if len({ + ary.expr + for ary in (idx_lambda,) + idx_lambdas_from_other_calls + if ary.shape[concatenatability.axis] != 1}) != 1: + raise _InvalidConcatenatability( + "Cannot concatenate the calls; required fields are not the same.") + bnd_name_to_concat = None + for ary in (idx_lambda,) + idx_lambdas_from_other_calls: + if ary.shape[concatenatability.axis] > 1: + bnd_name_to_concat = _get_binding_to_concatenatability( + ary.expr, concatenatability.axis, allow_indirect_addr) + break + if bnd_name_to_concat is None: + bnd_name_to_concat = _get_binding_to_concatenatability( + idx_lambda.expr, concatenatability.axis, allow_indirect_addr) + for bnd_name, bnd_concat in bnd_name_to_concat.items(): + self.rec(idx_lambda.bindings[bnd_name], bnd_concat, + tuple(ary.bindings[bnd_name] + for ary in idx_lambdas_from_other_calls)) + else: + raise NotImplementedError(type(concatenatability)) + + map_index_lambda = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + 
map_einsum = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_basic_index = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_roll = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_stack = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_concatenate = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_axis_permutation = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_reshape = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + + map_contiguous_advanced_index = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=True) + map_non_contiguous_advanced_index = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=True) + + # type-ignore-reason: CachedWalkMapper.map_call takes in variadic args, kwargs + def map_call(self, # type: ignore[override] + expr: Call, + exprs_from_other_calls: Tuple[Call, ...]) -> None: + if not all( + (self.current_stack, named_result) in self.ary_to_concatenatability + for named_result in expr.values()): + self.call_sites_on_hold.add(expr) + else: + self.call_sites_on_hold.remove(expr) + # FIXME The code below bypasses caching of function definitions + new_mapper = self.clone_with_new_call_on_stack(expr) + for name, val_in_callee in expr.function.returns.items(): + new_mapper(val_in_callee, + self.ary_to_concatenatability[(self.current_stack, + expr[name])], + tuple(other_call.function.returns[name] + for other_call in exprs_from_other_calls) + ) + + if new_mapper.call_sites_on_hold: + raise NotImplementedError("Call sites that do not all use all" + " the returned values not yet" + " supported for concatenation.") + + for ary, concat in new_mapper.ary_to_concatenatability.items(): + assert ary not in self.ary_to_concatenatability + self.ary_to_concatenatability[ary] = concat + + for name, binding in expr.bindings.items(): + if not isinstance(binding, Array): + continue + concat = ( + new_mapper + .ary_to_concatenatability[(self.current_stack + (expr,), + expr.function.get_placeholder(name))] + ) + self.rec(binding, + concat, + tuple(other_call.bindings[name] + for other_call in exprs_from_other_calls)) + + # type-ignore-reason: CachedWalkMapper's method takes in variadic args, kwargs + def map_named_call_result(self, expr: NamedCallResult, # type: ignore[override] + concatenatability: Concatenatability, + exprs_from_other_calls: Tuple[Array, ...], + ) -> None: + self._record_concatability(expr, concatenatability) + if any(not isinstance(ary, NamedCallResult) + for ary in exprs_from_other_calls): + raise _InvalidConcatenatability() + + # type-ignore-reason: mypy does not respect the conditional which + # asserts that all arrays in `exprs_from_other_calls` are + # NamedCallResult. + self.rec(expr._container, + tuple(ary._container # type: ignore[attr-defined] + for ary in exprs_from_other_calls) + ) + + def map_loopy_call_result(self, expr: "LoopyCallResult" + ) -> None: + raise ValueError("Loopy Calls are illegal to concatenate. 
Maybe" + " rewrite the operation as array operations?") + + +# Memoize the creation of concatenated input arrays to avoid copies +class _InputConcatenator: + def __init__(self, inherit_axes: bool): + self.inherit_axes = inherit_axes + + @memoize_method + def __call__(self, arrays, axis): + if self.inherit_axes: + concat_axis_tag = UseInputAxis(0, axis) + else: + concat_axis_tag = ConcatenatedCallInputConcatAxisTag() + return concatenate( + arrays, + axis + ).with_tagged_axis(axis, frozenset({concat_axis_tag})).tagged( + ImplStored()) + + +# Memoize the creation of sliced output arrays to avoid copies +class _OutputSlicer: + def __init__(self, inherit_axes: bool): + self.inherit_axes = inherit_axes + + @memoize_method + def _get_slice( + self, + ary: Array, + axis: int, + start_idx: ShapeComponent, + end_idx: ShapeComponent): + indices = [slice(None) for i in range(ary.ndim)] + indices[axis] = slice(start_idx, end_idx) + if self.inherit_axes: + slice_axis_tag = UseInputAxis(None, axis) + else: + slice_axis_tag = ConcatenatedCallOutputSliceAxisTag() + sliced_ary = ary[tuple(indices)].with_tagged_axis( + axis, frozenset({slice_axis_tag})).tagged(ImplStored()) + assert isinstance(sliced_ary, BasicIndex) + return sliced_ary + + def __call__(self, ary, axis, slice_sizes): + start_indices: List[ShapeComponent] = [] + end_indices: List[ShapeComponent] = [] + if len(slice_sizes) > 0: + start_indices.append(0) + end_indices.append(slice_sizes[0]) + for islice in range(1, len(slice_sizes)): + start_indices.append(end_indices[-1]) + end_indices.append(end_indices[-1] + slice_sizes[islice]) + return [ + self._get_slice(ary, axis, start_idx, end_idx) + for start_idx, end_idx in zip(start_indices, end_indices)] + + +class _FunctionConcatenator(TransformMapperWithExtraArgs[Tuple[Array, ...]]): + def __init__(self, + current_stack: Tuple[Call, ...], + input_concatenator: _InputConcatenator, + ary_to_concatenatability: Mapping[ArrayOnStackT, Concatenatability], + _cache: TransformMapperCache[ + ArrayOrNames, [Tuple[Array, ...]]] | None = None, + _function_cache: TransformMapperCache[ + FunctionDefinition, [Tuple[Array, ...]]] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.current_stack = current_stack + self.input_concatenator = input_concatenator + self.ary_to_concatenatability = ary_to_concatenatability + + def get_cache_key( + self, expr: ArrayOrNames, exprs_from_other_calls: tuple[Array, ...] + ) -> tuple[ArrayOrNames, ...]: + return (expr, *exprs_from_other_calls) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control should not reach here." + " Call clone_with_new_call_on_stack instead.") + + def clone_with_new_call_on_stack(self, expr: Call) -> Self: + return type(self)( + self.current_stack + (expr,), + self.input_concatenator, + self.ary_to_concatenatability, + _function_cache=self._function_cache + ) + + def _get_concatenatability(self, expr: Array) -> Concatenatability: + return self.ary_to_concatenatability[(self.current_stack, expr)] + + def map_placeholder(self, + expr: Placeholder, + exprs_from_other_calls: Tuple[Array, ...] 
+ ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + new_shape = _get_concatenated_shape( + (expr,) + exprs_from_other_calls, concat.axis) + return Placeholder(name=expr.name, + dtype=expr.dtype, + shape=new_shape, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_data_wrapper(self, + expr: DataWrapper, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + return self.input_concatenator( + (expr,) + exprs_from_other_calls, concat.axis) + else: + raise NotImplementedError(type(concat)) + + def map_index_lambda(self, + expr: IndexLambda, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, IndexLambda) + for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are IndexLambda. + new_bindings = { + bnd_name: self.rec( + subexpr, + tuple(ary.bindings[bnd_name] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + for bnd_name, subexpr in expr.bindings.items() + } + new_shape = _get_concatenated_shape((expr,) + exprs_from_other_calls, + concat.axis) + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=immutabledict(new_bindings), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_einsum(self, expr: Einsum, + exprs_from_other_calls: Tuple[Array, ...]) -> Array: + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Einsum) for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Einsum. + new_args = [self.rec(arg, + tuple(ary.args[iarg] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + for iarg, arg in enumerate(expr.args)] + + return Einsum(expr.access_descriptors, + tuple(new_args), + expr.redn_axis_to_redn_descr, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def _map_index_base(self, expr: IndexBase, + exprs_from_other_calls: Tuple[Array, ...]) -> Array: + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, IndexBase) for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are IndexBase. 
+ new_indices = [ + self.rec(idx, + tuple(ary.indices[i_idx] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + if isinstance(idx, Array) + else idx + for i_idx, idx in enumerate(expr.indices) + ] + new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + + return type(expr)(array=new_array, + indices=tuple(new_indices), + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + map_contiguous_advanced_index = _map_index_base + map_non_contiguous_advanced_index = _map_index_base + map_basic_index = _map_index_base + + def map_roll(self, + expr: Roll, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert concat.axis != expr.axis + assert all(isinstance(ary, Roll) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Roll. + return Roll(self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)), + shift=expr.shift, + axis=expr.axis, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_stack(self, expr: Stack, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Stack) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Stack. + if any(len(ary.arrays) != len(expr.arrays) # type: ignore[attr-defined] + for ary in exprs_from_other_calls): + raise ValueError("Cannot concatenate stack expressions" + " with different number of arrays.") + + new_arrays = tuple( + self.rec(array, + tuple(subexpr.arrays[iarray] # type: ignore[attr-defined] + for subexpr in exprs_from_other_calls) + ) + for iarray, array in enumerate(expr.arrays)) + + return Stack(new_arrays, + expr.axis, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_concatenate(self, expr: Concatenate, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Concatenate) + for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Concatenate. + if any(len(ary.arrays) != len(expr.arrays) # type: ignore[attr-defined] + for ary in exprs_from_other_calls): + raise ValueError("Cannot concatenate concatenate-expressions" + " with different number of arrays.") + + new_arrays = tuple( + self.rec(array, + tuple(subexpr.arrays[iarray] # type: ignore[attr-defined] + for subexpr in exprs_from_other_calls) + ) + for iarray, array in enumerate(expr.arrays) + ) + + return Concatenate(new_arrays, + expr.axis, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_axis_permutation(self, expr: AxisPermutation, + exprs_from_other_calls: Tuple[Array, ...] 
+ ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, AxisPermutation) + for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are AxisPermutation. + new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + return AxisPermutation(new_array, + expr.axis_permutation, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_reshape(self, expr: Reshape, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + new_newshape = _get_concatenated_shape( + (expr,) + exprs_from_other_calls, concat.axis) + + assert all(isinstance(ary, Reshape) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Reshape. + new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + return Reshape(new_array, + new_newshape, + expr.order, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_function_definition( + self, + expr: FunctionDefinition, + exprs_from_other_calls: Tuple[FunctionDefinition, ...] + ) -> FunctionDefinition: + # No clone here because we're cloning in map_call instead + new_returns = { + name: self.rec( + ret, + tuple( + other_expr.returns[name] + for other_expr in exprs_from_other_calls)) + for name, ret in expr.returns.items()} + if all( + new_ret is ret + for ret, new_ret in zip( + expr.returns.values(), + new_returns.values())): + return expr + else: + return attrs.evolve(expr, returns=immutabledict(new_returns)) + + def map_call(self, expr: Call, other_callsites: Tuple[Call, ...]) -> Call: + new_mapper = self.clone_with_new_call_on_stack(expr) + new_function = new_mapper.rec_function_definition( + expr.function, + tuple(other_call.function for other_call in other_callsites)) + new_bindings = {name: ( + self.rec( + bnd, tuple( + callsite.bindings[name] + for callsite in other_callsites)) + if isinstance(bnd, Array) + else bnd) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values()))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) + + def map_named_call_result(self, + expr: NamedCallResult, + exprs_from_other_calls: Tuple[Array, ...] 
+ ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, NamedCallResult) + for ary in exprs_from_other_calls) + assert isinstance(expr._container, Call) + new_call = self.rec( + expr._container, + tuple(ary._container # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + return new_call[expr.name] + else: + raise NotImplementedError(type(concat)) + + def map_loopy_call_result(self, expr: "LoopyCallResult", + exprs_from_other_calls: Tuple[Array, ...], + ) -> _InputConcatabilityGetterAcc: + raise ValueError("Loopy Calls are illegal to concatenate. Maybe" + " rewrite the operation as array operations?") + + +@memoize_on_first_arg +def _get_valid_concatenatability_constraints_simple( + template_call: Call, *other_calls: Call) -> Tuple[FunctionConcatenability]: + template_fn = template_call.function + mapper = _InputConcatabilityGetter() + output_accs = { + name: mapper( + *[cs.function.returns[name] for cs in (template_call,) + other_calls]) + for name in template_fn.returns} + + return _combine_named_result_accs_simple(output_accs) + + +@memoize_on_first_arg +def _get_valid_concatenatability_constraints_exhaustive( + fn: FunctionDefinition) -> Generator[ + FunctionConcatenability, + None, + None]: + mapper = _InputConcatabilityGetter() + output_accs = {name: mapper(output) + for name, output in fn.returns.items()} + + yield from _combine_named_result_accs_exhaustive(output_accs) + + +def _get_ary_to_concatenatabilities(call_sites: Sequence[Call], + ) -> Generator[Mapping[ArrayOnStackT, + Concatenatability], + None, + None]: + """ + Generates a :class:`Concatenatability` criterion for each array in the + expression graph of *call_sites*'s function body if they traverse identical + function bodies. + """ + fn_concatenatabilities = \ + _get_valid_concatenatability_constraints_simple(*call_sites) + + # select a template call site to start the traversal. + template_call, *other_calls = call_sites + template_fn = template_call.function + fid = next(iter(template_fn.tags_of_type(FunctionIdentifier))) + + concat_idx_to_err_msg = {} + + for iconcat, fn_concatenatability in enumerate(fn_concatenatabilities): + collector = _ConcatabilityCollector(current_stack=()) + + try: + # verify the constraints on parameters are satisfied + for name, input_concat in (fn_concatenatability + .input_to_concatenatability + .items()): + try: + if isinstance(input_concat, ConcatableIfConstant): + _verify_arrays_same([cs.bindings[name] for cs in call_sites]) + elif isinstance(input_concat, ConcatableAlongAxis): + _verify_arrays_can_be_concated_along_axis( + [cs.bindings[name] for cs in call_sites], + [], + input_concat.axis) + else: + raise NotImplementedError(type(input_concat)) + except _InvalidConcatenatability as e: + raise _InvalidConcatenatability( + f"Binding for input {name} is not concatenatable. {str(e)}") + + # verify the constraints on function bodies are satisfied + for name, output_concat in (fn_concatenatability + .output_to_concatenatability + .items()): + try: + collector(template_call.function.returns[name], + output_concat, + tuple(other_call.function.returns[name] + for other_call in other_calls)) + except _InvalidConcatenatability as e: + raise _InvalidConcatenatability( + f"Function output {name} is not concatenatable. 
{str(e)}") + except _InvalidConcatenatability as e: + concat_idx_to_err_msg[iconcat] = str(e) + else: + if collector.call_sites_on_hold: + raise NotImplementedError("Expressions that use part of" + " function's returned values are not" + " yet supported.") + + logger.info( + f"Found a valid concatenatability for function with ID '{fid}' --\n" + f"{fn_concatenatability}") + + yield immutabledict(collector.ary_to_concatenatability) + + log_str = ( + f"No more valid concatenatabilities for function with ID '{fid}'. " + "Unsuitable candidates:\n") + for iconcat, fn_concatenatability in enumerate(fn_concatenatabilities): + try: + err_msg = concat_idx_to_err_msg[iconcat] + except KeyError: + continue + log_str += f"Candidate:\n{fn_concatenatability}\n" + log_str += f"Error: {concat_idx_to_err_msg[iconcat]}\n\n" + logger.info(log_str) + + +def _get_replacement_map_post_concatenating( + call_sites: Sequence[Call], + used_call_results: FrozenOrderedSet(NamedCallResult), + input_concatenator: _InputConcatenator, + output_slicer: _OutputSlicer) -> Mapping[NamedCallResult, Array]: + """ + .. note:: + + We require *call_sites* to be ordered to determine the concatenation + order. + """ + assert call_sites, "Empty `call_sites`." + + ary_to_concatenatabilities = _get_ary_to_concatenatabilities(call_sites) + + template_call_site, *other_call_sites = call_sites + template_function = template_call_site.function + fid = next(iter(template_function.tags_of_type(FunctionIdentifier))) + + try: + ary_to_concatenatability = next(ary_to_concatenatabilities) + except StopIteration: + raise ValueError( + f"No valid concatenatibilities found for function with ID '{fid}'.") + else: + if __debug__: + try: + next(ary_to_concatenatabilities) + except StopIteration: + # unique concatenatibility + pass + else: + from warnings import warn + # TODO: Take some input from the user to resolve this ambiguity. + warn( + "Multiple concatenation possibilities found for function with " + f"ID '{fid}'. This may lead to non-deterministic transformed " + "expression graphs.") + + # {{{ actually perform the concatenation + + template_returns = template_function.returns + template_bindings = template_call_site.bindings + + function_concatenator = _FunctionConcatenator( + current_stack=(), input_concatenator=input_concatenator, + ary_to_concatenatability=ary_to_concatenatability) + + if __debug__: + # FIXME: We may be able to handle this without burdening the user + # See https://github.com/inducer/pytato/issues/559 + from collections import defaultdict + param_to_used_calls = defaultdict(OrderedSet) + for output_name in template_call_site.keys(): + for csite in call_sites: + call_result = csite[output_name] + if call_result in used_call_results: + ret = csite.function.returns[output_name] + used_params = ( + OrderedSet( + expr.name + for expr in InputGatherer()(ret)) + & csite.function.parameters) + for name in used_params: + param_to_used_calls[name].add(csite) + for name, used_calls in param_to_used_calls.items(): + if used_calls != OrderedSet(call_sites): + from warnings import warn + warn( + f"DAG output does not depend on parameter '{name}' for some " + f"calls to function with ID '{fid}'. Concatenation will prevent " + "these unused inputs from being removed from the DAG when the " + "function is inlined. 
This may lead to unnecessary computation.") + + # new_returns: concatenated function body + new_returns: Dict[str, Array] = {} + for output_name in template_call_site.keys(): + new_returns[output_name] = function_concatenator( + template_returns[output_name], + tuple(csite.function.returns[output_name] + for csite in other_call_sites)) + + # }}} + + # construct new function body + if any( + new_returns[output_name] is not template_returns[output_name] + for output_name in template_returns): + new_function = FunctionDefinition( + template_call_site.function.parameters, + template_call_site.function.return_type, + immutabledict(new_returns), + tags=template_call_site.function.tags) + else: + new_function = template_call_site.function + + result: Dict[NamedCallResult, Array] = {} + + new_call_bindings: Dict[str, Array] = {} + + # construct new bindings + for param_name in template_bindings: + param_placeholder = template_call_site.function.get_placeholder(param_name) + param_concat = ary_to_concatenatability[((), param_placeholder)] + if isinstance(param_concat, ConcatableAlongAxis): + param_bindings = tuple([ + csite.bindings[param_name] + for csite in call_sites]) + new_binding = input_concatenator( + param_bindings, + param_concat.axis) + elif isinstance(param_concat, ConcatableIfConstant): + new_binding = template_bindings[param_name] + else: + raise NotImplementedError(type(param_concat)) + new_call_bindings[param_name] = new_binding + + # construct new call + if ( + new_function is not template_call_site.function + or any( + new_call_bindings[param_name] is not template_bindings[param_name] + for param_name in template_bindings)): + new_call = Call( + function=new_function, + bindings=immutabledict(new_call_bindings), + tags=template_call_site.tags) + else: + new_call = template_call_site + + # slice into new_call's outputs to replace the old expressions. + for output_name, output_ary in (template_call_site + .function + .returns + .items()): + concat = ary_to_concatenatability[((), output_ary)] + new_return = new_call[output_name] + if isinstance(concat, ConcatableIfConstant): + # FIXME: Does it make sense to not concatenate if some arguments are + # ConcatableIfConstant and some are ConcatableAlongAxis? Seems like that + # would cause problems... + for cs in call_sites: + result[cs[output_name]] = new_return + elif isinstance(concat, ConcatableAlongAxis): + slice_sizes = [ + cs[output_name].shape[concat.axis] + for cs in call_sites] + output_slices = output_slicer(new_return, concat.axis, slice_sizes) + for cs, output_slice in zip(call_sites, output_slices): + result[cs[output_name]] = output_slice + else: + raise NotImplementedError(type(concat)) + + return immutabledict(result) + + +def concatenate_calls(expr: ArrayOrNames, + call_site_filter: Callable[[CallSiteLocation], bool], + *, + inherit_axes: bool = False, + warn_if_no_calls: bool = True, + err_if_no_calls: bool = False, + ignore_tag_types: frozenset(type) | None = None, + ) -> ArrayOrNames: + r""" + Returns a copy of *expr* after concatenating all call-sites ``C`` such that + ``call_site_filter(C) is True``. + + :arg call_site_filter: A callable to select which instances of + :class:`~pytato.function.Call`\ s must be concatenated. 
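+    :arg inherit_axes: If *True*, tag the concatenation axis of the merged
+        inputs (and the slice axis of the split outputs) with
+        :class:`~pytato.tags.UseInputAxis`, so that axis metadata can be
+        inherited from the original arrays; otherwise these axes are tagged
+        as newly created concatenation/slice axes.
+    :arg err_if_no_calls: If *True*, raise an error when a selected function
+        has fewer than two matching call sites and hence nothing to
+        concatenate.
+    :arg warn_if_no_calls: If *True*, emit a warning instead of an error in
+        that situation.
+    :arg ignore_tag_types: Tag types to be ignored when deciding whether two
+        call sites are similar enough to be concatenated.
+
+    For example (a sketch, assuming *expr* contains call sites of a function
+    tagged with ``FunctionIdentifier("f")``), all such call sites can be
+    concatenated via::
+
+        expr = pt.concatenate_calls(
+            expr,
+            lambda cs: pt.tags.FunctionIdentifier("f") in cs.call.function.tags)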
+ """ + if ignore_tag_types is None: + ignore_tag_types: frozenset(type) = frozenset() + + call_site_collector = CallSiteDependencyCollector(stack=()) + + all_call_sites = call_site_collector(expr) + filtered_call_sites = FrozenOrderedSet(cs + for cs in all_call_sites + if call_site_filter(cs)) + + function_ids = FrozenOrderedSet( + next(iter(cs.call.function.tags_of_type(FunctionIdentifier))) + for cs in filtered_call_sites) + + # Input concatenator needs to be set up outside of the loop in order to prevent + # creating duplicates; probably not strictly necessary for output slicer + input_concatenator = _InputConcatenator(inherit_axes=inherit_axes) + output_slicer = _OutputSlicer(inherit_axes=inherit_axes) + + result = expr + + for fid in function_ids: + call_site_dep_collector = CallSiteDependencyCollector(stack=()) + call_site_dep_collector(result) + + call_site_to_dep_call_sites = \ + call_site_dep_collector.call_site_to_dep_call_sites + + unbatched_call_sites: OrderedSet[CallSiteLocation] = OrderedSet( + cs for cs in call_site_to_dep_call_sites.keys() + if call_site_filter(cs) and fid in cs.call.function.tags) + + for cs in unbatched_call_sites: + for ret in cs.call.function.returns.values(): + nested_calls = collect_nodes_of_type(ret, Call) + if nested_calls: + raise NotImplementedError( + "Concatenation of nested calls is not yet supported.") + + call_site_batches: List[FrozenOrderedSet[CallSiteLocation]] = [] + + replacement_map: Dict[ + Tuple[NamedCallResult, Tuple[Call, ...]], + Array] = {} + + used_call_results = collect_nodes_of_type(result, NamedCallResult) + + while unbatched_call_sites: + ready_call_sites = FrozenOrderedSet( + cs for cs in unbatched_call_sites + if not call_site_to_dep_call_sites[cs] & unbatched_call_sites) + + from mpi4py import MPI + rank = MPI.COMM_WORLD.rank + + # if fid.identifier == "_make_fluid_state": + # print(f"{rank}: {len(ready_call_sites)=}") + + if not ready_call_sites: + raise ValueError("Found cycle in call site dependency graph.") + + template_call_site = next(iter(ready_call_sites)) + template_fn = template_call_site.call.function + + from pytato.equality import SimilarityComparer + similarity_comparer = SimilarityComparer( + ignore_tag_types=ignore_tag_types) + # err_on_not_similar=(fid.identifier == "_make_fluid_state")) + + # if fid.identifier == "_make_fluid_state": + # for cs in ready_call_sites: + # same_outputs = ( + # frozenset(cs.call.function.returns.keys()) + # == frozenset(template_fn.returns.keys())) + # similar = all( + # similarity_comparer( + # cs.call.function.returns[name], + # template_fn.returns[name]) + # for name in template_fn.returns) + # same_stack = (cs.stack == template_call_site.stack) + # print(f"{rank}: {same_outputs=}, {similar=}, {same_stack=}") + # # if not similar: + # # for name in template_fn.returns: + # # from pytato.analysis import get_num_nodes + # # nnodes_template = get_num_nodes(template_fn.returns[name]) + # # nnodes_other = get_num_nodes(cs.call.function.returns[name]) + # # print(f"{rank}: {name=}, {nnodes_template=}, {nnodes_other=}") + + similar_call_sites = FrozenOrderedSet( + cs for cs in ready_call_sites + if ( + ( + FrozenOrderedSet(cs.call.function.returns.keys()) + == FrozenOrderedSet(template_fn.returns.keys())) + and all( + similarity_comparer( + cs.call.function.returns[name], + template_fn.returns[name]) + for name in template_fn.returns) + and cs.stack == template_call_site.stack)) + + # if fid.identifier == "_make_fluid_state": + # print(f"{rank}: {len(similar_call_sites)=}") + + 
if not similar_call_sites: + raise ValueError("Failed to find similar call sites to concatenate.") + + # def get_axis0_len(cs): + # first_out_name = next(iter(cs.call.function.returns.keys())) + # axis0_len = cs.call[first_out_name].shape[0] + # assert all( + # cs.call[name].shape[0] == axis0_len + # for name in cs.call.function.returns) + # return axis0_len + + # batch_call_sites = FrozenOrderedSet(sorted(similar_call_sites, key=get_axis0_len)) + batch_call_sites = similar_call_sites + + call_site_batches.append(batch_call_sites) + unbatched_call_sites -= batch_call_sites + + # FIXME: this doesn't work; need to create/execute batches one at a time, + # then repeat the steps above to collect the updated call sites after + # concatenating the previous batch + for ibatch, call_sites in enumerate(call_site_batches): + from mpi4py import MPI + rank = MPI.COMM_WORLD.rank + + template_fn = next(iter(call_sites)).call.function + + # FIXME: Can't currently call get_num_nodes on a function definition + from pytato.array import make_dict_of_named_arrays + from pytato.analysis import get_num_nodes + fn_body = make_dict_of_named_arrays(template_fn.returns) + nnodes = get_num_nodes(fn_body) + + print( + f"{rank}: Concatenating function '{fid}' (batch {ibatch+1} of " + f"{len(call_site_batches)}: {nnodes} nodes, {len(call_sites)} " + "call sites).") + + if len(call_sites) <= 1: + if err_if_no_calls: + raise ValueError( + f"Not enough calls to concatenate function with ID '{fid}'.") + elif warn_if_no_calls: + from warnings import warn + warn( + f"Not enough calls to concatenate function with ID '{fid}'.", + stacklevel=2) + else: + pass + continue + + old_expr_to_new_expr_map = _get_replacement_map_post_concatenating( + [cs.call for cs in call_sites], + used_call_results, + input_concatenator=input_concatenator, + output_slicer=output_slicer) + + stack, = FrozenOrderedSet(cs.stack for cs in call_sites) + + replacement_map.update({ + (old_expr, stack): new_expr + for old_expr, new_expr in old_expr_to_new_expr_map.items()}) + + # FIXME: Still getting some duplicated `Concatenate`s, not sure why + dedup = Deduplicator() + result = dedup(result) + replacement_map = { + old_expr_and_stack: dedup(new_expr) + for old_expr_and_stack, new_expr in replacement_map.items()} + + result = _NamedCallResultReplacerPostConcatenate( + replacement_map=replacement_map, + current_stack=())(result) + + assert isinstance(result, (Array, AbstractResultWithNamedArrays)) + return result + +# }}} + # vim:foldmethod=marker diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 8cd635f61..694901b03 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -57,6 +57,8 @@ Stack, ) from pytato.transform import ( + ArrayOrNames, + CacheKeyT, MappedT, TransformMapperWithExtraArgs, _verify_is_array, @@ -160,6 +162,13 @@ def __init__(self, super().__init__() self.how_to_distribute = how_to_distribute + def get_cache_key( + self, + expr: ArrayOrNames, + ctx: _EinsumDistributiveLawMapperContext | None + ) -> CacheKeyT: + return (expr, ctx) + def _map_input_base(self, expr: InputArgumentBase, ctx: _EinsumDistributiveLawMapperContext | None, diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 507a450cd..2898b0cef 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -53,15 +53,18 @@ ShapeComponent, ShapeType, Stack, + 
_get_einsum_access_descr_to_axis_len, ) from pytato.diagnostic import CannotBeLoweredToIndexLambda from pytato.scalar_expr import INT_CLASSES, ScalarExpression from pytato.tags import AssumeNonNegative -from pytato.transform import Mapper +from pytato.transform import IndexOrShapeExpr, Mapper from pytato.utils import normalized_slice_does_not_change_axis if TYPE_CHECKING: + from collections.abc import Mapping + import numpy as np @@ -126,16 +129,14 @@ def _generate_index_expressions( for old_size_till, old_stride in zip(old_size_tills, old_strides, strict=True)) -def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: +def _get_reshaped_indices( + order: str, old_shape: ShapeType, new_shape: ShapeType + ) -> tuple[ScalarExpression, ...]: - if expr.order.upper() not in ["C", "F"]: + if order.upper() not in ["C", "F"]: raise NotImplementedError("Order expected to be 'C' or 'F'", " (case insensitive). Found order = ", - f"{expr.order}") - - order = expr.order - old_shape = expr.array.shape - new_shape = expr.shape + f"{order}") # index variables need to be unique and depend on the new shape length index_vars = [prim.Variable(f"_{i}") for i in range(len(new_shape))] @@ -143,9 +144,86 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: # {{{ check for scalars if old_shape == (): - assert expr.size == 1 + from pytools import product + assert product(new_shape) == 1 return () + if order not in ["C", "F"]: + raise NotImplementedError("Order expected to be 'C' or 'F'", + f" found {order}") + + non1_shape = [] + for axis_len in new_shape: + assert isinstance(axis_len, INT_CLASSES) + if axis_len > 1: + non1_shape.append(axis_len) + non1_shape = tuple(non1_shape) + + old_non1_shape = [] + for axis_len in old_shape: + assert isinstance(axis_len, INT_CLASSES) + if axis_len > 1: + old_non1_shape.append(axis_len) + old_non1_shape = tuple(old_non1_shape) + + if non1_shape == old_non1_shape: + non1_axes = tuple( + iaxis for iaxis in range(len(new_shape)) + if new_shape[iaxis] > 1) + old_non1_axes = tuple( + iaxis for iaxis in range(len(old_shape)) + if old_shape[iaxis] > 1) + old_iaxis_to_iaxis = { + old_iaxis: iaxis + for old_iaxis, iaxis in zip( + old_non1_axes, non1_axes)} + return tuple( + prim.Variable(f"_{old_iaxis_to_iaxis[old_iaxis]}") + if old_iaxis in old_iaxis_to_iaxis + else 0 + for old_iaxis in range(len(old_shape))) + + if order == "C": + newstrides: list[IntegralT] = [1] # reshaped array strides + for new_axis_len in reversed(new_shape[1:]): + assert isinstance(new_axis_len, INT_CLASSES) + newstrides.insert(0, newstrides[0]*new_axis_len) + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + oldstrides: list[IntegralT] = [1] # input array strides + for axis_len in reversed(old_shape[1:]): + assert isinstance(axis_len, INT_CLASSES) + oldstrides.insert(0, oldstrides[0]*axis_len) + + assert isinstance(old_shape[-1], INT_CLASSES) + oldsizetills = [old_shape[-1]] # input array size + # till for axes idx + for old_axis_len in reversed(old_shape[:-1]): + assert isinstance(old_axis_len, INT_CLASSES) + oldsizetills.insert(0, oldsizetills[0]*old_axis_len) + + else: + newstrides: list[IntegralT] = [1] # reshaped array strides + for new_axis_len in new_shape[:-1]: + assert isinstance(new_axis_len, INT_CLASSES) + newstrides.append(newstrides[-1]*new_axis_len) + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + oldstrides: list[IntegralT] = [1] # input array strides + for 
axis_len in old_shape[:-1]: + assert isinstance(axis_len, INT_CLASSES) + oldstrides.append(oldstrides[-1]*axis_len) + + assert isinstance(old_shape[0], INT_CLASSES) + oldsizetills = [old_shape[0]] # input array size till for axes idx + for old_axis_len in old_shape[1:]: + assert isinstance(old_axis_len, INT_CLASSES) + oldsizetills.append(oldsizetills[-1]*old_axis_len) + if new_shape == (): return _generate_index_expressions(old_shape, new_shape, order, index_vars) @@ -256,10 +334,17 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: class ToIndexLambdaMixin: - def _rec_shape(self, shape: ShapeType) -> ShapeType: - return tuple(self.rec(s) if isinstance(s, Array) - else s - for s in shape) + def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] + ) -> tuple[IndexOrShapeExpr, ...]: + # type-ignore-reason: apparently mypy cannot substitute typevars + # here. + new_situp = tuple( + self.rec(s) if isinstance(s, Array) else s + for s in situp) + if all(new_s is s for s, new_s in zip(situp, new_situp, strict=True)): + return situp + else: + return new_situp # type: ignore[return-value] if TYPE_CHECKING: def rec( @@ -270,17 +355,27 @@ def rec( return super().rec( # type: ignore[no-any-return,misc] expr, *args, **kwargs) - def map_index_lambda(self, expr: IndexLambda) -> IndexLambda: - return IndexLambda(expr=expr.expr, - shape=self._rec_shape(expr.shape), - dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) - for name, bnd - in sorted(expr.bindings.items())}), - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + def map_index_lambda(self, expr: IndexLambda) -> Array: + new_shape = self.rec_idx_or_size_tuple(expr.shape) + new_bindings: Mapping[str, Array] = immutabledict({ + name: self.rec(subexpr) + for name, subexpr in sorted(expr.bindings.items())}) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> IndexLambda: subscript = tuple(prim.Variable(f"_{i}") @@ -305,11 +400,11 @@ def map_stack(self, expr: Stack) -> IndexLambda: subarray_expr, stack_expr) - bindings = {f"_in{i}": self.rec(array) - for i, array in enumerate(expr.arrays)} + bindings = {f"_in{i}": self.rec(ary) + for i, ary in enumerate(expr.arrays)} return IndexLambda(expr=stack_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, bindings=immutabledict(bindings), @@ -328,10 +423,12 @@ def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: for i in range(len(expr.shape))] return Subscript(aggregate, tuple(index)) + rec_arrays: tuple[Array, ...] 
= tuple(self.rec(ary) for ary in expr.arrays) + lbounds: list[Any] = [0] - ubounds: list[Any] = [expr.arrays[0].shape[expr.axis]] + ubounds: list[Any] = [rec_arrays[0].shape[expr.axis]] - for i, array in enumerate(expr.arrays[1:], start=1): + for i, array in enumerate(rec_arrays[1:], start=1): ubounds.append(ubounds[i-1]+array.shape[expr.axis]) lbounds.append(ubounds[i-1]) @@ -354,11 +451,11 @@ def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: subarray_expr, concat_expr) - bindings = {f"_in{i}": self.rec(array) - for i, array in enumerate(expr.arrays)} + bindings = {f"_in{i}": ary + for i, ary in enumerate(rec_arrays)} return IndexLambda(expr=concat_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, bindings=immutabledict(bindings), axes=expr.axes, @@ -377,7 +474,9 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: dim_to_index_lambda_components, ) - bindings = {f"_in{k}": self.rec(arg) for k, arg in enumerate(expr.args)} + rec_args: tuple[Array, ...] = tuple(self.rec(arg) for arg in expr.args) + + bindings = {f"_in{k}": arg for k, arg in enumerate(rec_args)} redn_bounds: dict[str, tuple[ScalarExpression, ScalarExpression]] = {} args_as_pym_expr: list[prim.Subscript] = [] namegen = UniqueNameGenerator(set(bindings)) @@ -385,13 +484,16 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: # {{{ add bindings coming from the shape expressions + access_descr_to_axis_len = _get_einsum_access_descr_to_axis_len( + expr.access_descriptors, rec_args) + for access_descr, (iarg, arg) in zip(expr.access_descriptors, - enumerate(expr.args), strict=True): + enumerate(rec_args), strict=True): subscript_indices: list[ArithmeticExpression] = [] for iaxis, axis in enumerate(access_descr): if not are_shape_components_equal( arg.shape[iaxis], - expr._access_descr_to_axis_len()[axis]): + access_descr_to_axis_len[axis]): # axis is broadcasted assert are_shape_components_equal(arg.shape[iaxis], 1) subscript_indices.append(0) @@ -432,7 +534,7 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: immutabledict(redn_bounds)) return IndexLambda(expr=inner_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, bindings=immutabledict(bindings), axes=expr.axes, @@ -443,12 +545,14 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: def map_roll(self, expr: Roll) -> IndexLambda: from pytato.utils import dim_to_index_lambda_components + rec_array = self.rec(expr.array) + index_expr: prim.ExpressionNode = prim.Variable("_in0") indices: list[ArithmeticExpression] = [ prim.Variable(f"_{d}") for d in range(expr.ndim)] axis = expr.axis axis_len_expr, bindings = dim_to_index_lambda_components( - expr.shape[axis], + rec_array.shape[axis], UniqueNameGenerator({"_in0"})) # Mypy has a point: the type system does not prove that the operands are @@ -459,13 +563,12 @@ def map_roll(self, expr: Roll) -> IndexLambda: index_expr = index_expr[tuple(indices)] # type-ignore-reason: `bindings` was returned as Dict[str, SizeParam] - bindings["_in0"] = expr.array # type: ignore[assignment] + bindings["_in0"] = rec_array # type: ignore[assignment] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) - for name, bnd in bindings.items()}), + bindings=immutabledict(bindings), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, @@ -476,27 +579,30 @@ def 
map_contiguous_advanced_index(self, ) -> IndexLambda: from pytato.utils import get_indexing_expression, get_shape_after_broadcasting + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + i_adv_indices = tuple(i - for i, idx_expr in enumerate(expr.indices) + for i, idx_expr in enumerate(rec_indices) if isinstance(idx_expr, (Array, *INT_CLASSES))) adv_idx_shape = get_shape_after_broadcasting([ - cast("Array | int | np.integer[Any]", expr.indices[i_idx]) + cast("Array | int | np.integer[Any]", rec_indices[i_idx]) for i_idx in i_adv_indices]) vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = 0 for i_idx, (idx, axis_len) in enumerate( - zip(expr.indices, expr.array.shape, strict=True)): + zip(rec_indices, rec_array.shape, strict=True)): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -508,7 +614,7 @@ def map_contiguous_advanced_index(self, elif isinstance(idx, Array): if isinstance(axis_len, INT_CLASSES): bnd_name = vng("in") - bindings[bnd_name] = self.rec(idx) + bindings[bnd_name] = idx indirect_idx_expr: ArithmeticExpression = prim.Subscript( prim.Variable(bnd_name), get_indexing_expression( @@ -536,7 +642,7 @@ def map_contiguous_advanced_index(self, return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -547,28 +653,32 @@ def map_contiguous_advanced_index(self, def map_non_contiguous_advanced_index( self, expr: AdvancedIndexInNoncontiguousAxes) -> IndexLambda: from pytato.utils import get_indexing_expression, get_shape_after_broadcasting + + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + i_adv_indices = tuple(i - for i, idx_expr in enumerate(expr.indices) + for i, idx_expr in enumerate(rec_indices) if isinstance(idx_expr, (Array, *INT_CLASSES))) adv_idx_shape = get_shape_after_broadcasting([ - cast("Array | int | np.integer[Any]", expr.indices[i_idx]) + cast("Array | int | np.integer[Any]", rec_indices[i_idx]) for i_idx in i_adv_indices]) vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = len(adv_idx_shape) - for idx, axis_len in zip(expr.indices, expr.array.shape, strict=True): + for idx, axis_len in zip(rec_indices, rec_array.shape, strict=True): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -580,7 +690,7 @@ def map_non_contiguous_advanced_index( elif isinstance(idx, Array): if isinstance(axis_len, INT_CLASSES): bnd_name = vng("in") - bindings[bnd_name] = self.rec(idx) + bindings[bnd_name] = idx indirect_idx_expr: ArithmeticExpression = 
prim.Subscript( prim.Variable(bnd_name), @@ -605,7 +715,7 @@ def map_non_contiguous_advanced_index( return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -614,20 +724,23 @@ def map_non_contiguous_advanced_index( ) def map_basic_index(self, expr: BasicIndex) -> IndexLambda: + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = 0 - for idx, axis_len in zip(expr.indices, expr.array.shape, strict=True): + for idx, axis_len in zip(rec_indices, rec_array.shape, strict=True): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -642,7 +755,7 @@ def map_basic_index(self, expr: BasicIndex) -> IndexLambda: return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -651,18 +764,22 @@ def map_basic_index(self, expr: BasicIndex) -> IndexLambda: ) def map_reshape(self, expr: Reshape) -> IndexLambda: - indices = _get_reshaped_indices(expr) + rec_array = self.rec(expr.array) + rec_newshape = self.rec_idx_or_size_tuple(expr.shape) + indices = _get_reshaped_indices(expr.order, rec_array.shape, rec_newshape) index_expr = prim.Variable("_in0")[tuple(indices)] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=rec_newshape, dtype=expr.dtype, - bindings=immutabledict({"_in0": self.rec(expr.array)}), + bindings=immutabledict({"_in0": rec_array}), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: + rec_array = self.rec(expr.array) + indices: list[ArithmeticExpression | None] = [None] * expr.ndim for from_index, to_index in enumerate(expr.axis_permutation): indices[to_index] = prim.Variable(f"_{from_index}") @@ -671,9 +788,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: cast("tuple[ArithmeticExpression]", tuple(indices))] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, - bindings=immutabledict({"_in0": self.rec(expr.array)}), + bindings=immutabledict({"_in0": rec_array}), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index d50da22e0..e3e880cf7 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -70,6 +70,7 @@ IndexLambda, InputArgumentBase, NamedArray, + NormalizedSlice, Reshape, Stack, ) @@ -79,8 +80,10 @@ IDX_LAMBDA_AXIS_INDEX, CombineMapper, ) +from pytato.tags import UseInputAxis from pytato.transform import ArrayOrNames, CopyMapper, Mapper, TransformMapperCache from 
pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.utils import are_shape_components_equal logger = logging.getLogger(__name__) @@ -326,7 +329,26 @@ def map_stack(self, expr: Stack) -> None: def map_concatenate(self, expr: Concatenate) -> None: for ary in expr.arrays: self.rec(ary) - self.add_equations_using_index_lambda_version_of_expr(expr) + # FIXME: Figure out how to integrate the UseInputAxis stuff into + # add_equations_using_index_lambda_version_of_expr + # self.add_equations_using_index_lambda_version_of_expr(expr) + for ary in expr.arrays: + assert ary.ndim == expr.ndim + for iaxis in range(expr.ndim): + if iaxis == expr.axis: + use_input_axis_tags = expr.axes[iaxis].tags_of_type( + UseInputAxis) + if use_input_axis_tags: + tag, = use_input_axis_tags + self.record_equation( + self.get_var_for_axis(expr.arrays[tag.key], tag.axis), + self.get_var_for_axis(expr, iaxis)) + else: + # non-concatenated axes share the dimensions. + self.record_equation( + self.get_var_for_axis(ary, iaxis), + self.get_var_for_axis(expr, iaxis) + ) def map_axis_permutation(self, expr: AxisPermutation ) -> None: @@ -335,7 +357,36 @@ def map_axis_permutation(self, expr: AxisPermutation def map_basic_index(self, expr: BasicIndex) -> None: self.rec(expr.array) - self.add_equations_using_index_lambda_version_of_expr(expr) + # FIXME: Figure out how to integrate the UseInputAxis stuff into + # add_equations_using_index_lambda_version_of_expr + # self.add_equations_using_index_lambda_version_of_expr(expr) + i_out_axis = 0 + + assert len(expr.indices) == expr.array.ndim + + for i_in_axis, idx in enumerate(expr.indices): + if isinstance(idx, int): + pass + else: + assert isinstance(idx, NormalizedSlice) + use_input_axis_tags = expr.axes[i_out_axis].tags_of_type( + UseInputAxis) + if use_input_axis_tags: + tag, = use_input_axis_tags + self.record_equation( + self.get_var_for_axis(expr.array, tag.axis), + self.get_var_for_axis(expr, i_out_axis)) + elif (idx.step == 1 + and are_shape_components_equal(idx.start, 0) + and are_shape_components_equal(idx.stop, + expr.array.shape[i_in_axis])): + + self.record_equation( + self.get_var_for_axis(expr.array, i_in_axis), + self.get_var_for_axis(expr, i_out_axis) + ) + + i_out_axis += 1 def map_contiguous_advanced_index(self, expr: AdvancedIndexInContiguousAxes @@ -416,9 +467,9 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]], tag_corresponding_redn_descr: bool, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None): + TransformMapperCache[FunctionDefinition, []] | None = None): super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]] = axis_to_tags @@ -465,9 +516,9 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: return result def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -478,7 +529,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = 
self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(expr, result, key=key) + return self._cache_add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( diff --git a/pytato/transform/remove_broadcasts_einsum.py b/pytato/transform/remove_broadcasts_einsum.py index 2d8f7e0f0..50ee4967c 100644 --- a/pytato/transform/remove_broadcasts_einsum.py +++ b/pytato/transform/remove_broadcasts_einsum.py @@ -28,46 +28,118 @@ THE SOFTWARE. """ -from typing import cast +from typing import TYPE_CHECKING, cast from pytato.array import Array, Einsum, EinsumAxisDescriptor -from pytato.transform import CopyMapper, MappedT, _verify_is_array +from pytato.transform import ( + ArrayOrNames, + CacheKeyT, + CopyMapperWithExtraArgs, + MappedT, + Mapper, + _verify_is_array, +) from pytato.utils import are_shape_components_equal -class EinsumWithNoBroadcastsRewriter(CopyMapper): - def map_einsum(self, expr: Einsum) -> Array: +if TYPE_CHECKING: + from pytato.function import FunctionDefinition + + +class EinsumWithNoBroadcastsRewriter(CopyMapperWithExtraArgs[[tuple[int, ...] | None]]): + def get_cache_key( + self, + expr: ArrayOrNames, + axes_to_squeeze: tuple[int, ...] | None = None + ) -> CacheKeyT: + return (expr, axes_to_squeeze) + + def get_function_definition_cache_key( + self, + expr: FunctionDefinition, + axes_to_squeeze: tuple[int, ...] | None = None + ) -> CacheKeyT: + assert axes_to_squeeze is None + return expr + + def _squeeze_axes( + self, + expr: Array, + axes_to_squeeze: tuple[int, ...] | None = None) -> Array: + result = ( + expr[ + tuple( + slice(None) if idim not in axes_to_squeeze else 0 + for idim in range(expr.ndim))] + if axes_to_squeeze else expr) + return result + + def rec( + self, + expr: ArrayOrNames, + axes_to_squeeze: tuple[int, ...] | None = None) -> ArrayOrNames: + inputs = self._make_cache_inputs(expr, axes_to_squeeze) + try: + return self._cache_retrieve(inputs) + except KeyError: + rec_result: ArrayOrNames = Mapper.rec(self, expr, None) + result: ArrayOrNames + if isinstance(expr, Array): + result = self._squeeze_axes( + _verify_is_array(rec_result), + axes_to_squeeze) + else: + result = rec_result + return self._cache_add(inputs, result) + + def map_einsum( + self, expr: Einsum, axes_to_squeeze: tuple[int, ...] 
| None) -> Array: new_args: list[Array] = [] new_access_descriptors: list[tuple[EinsumAxisDescriptor, ...]] = [] descr_to_axis_len = expr._access_descr_to_axis_len() - for acc_descrs, arg in zip(expr.access_descriptors, expr.args, strict=True): - arg = _verify_is_array(self.rec(arg)) - axes_to_squeeze: list[int] = [] + for arg, acc_descrs in zip(expr.args, expr.access_descriptors, strict=True): + axes_to_squeeze_list: list[int] = [] for idim, acc_descr in enumerate(acc_descrs): if not are_shape_components_equal(arg.shape[idim], descr_to_axis_len[acc_descr]): assert are_shape_components_equal(arg.shape[idim], 1) - axes_to_squeeze.append(idim) + axes_to_squeeze_list.append(idim) + axes_to_squeeze = tuple(axes_to_squeeze_list) if axes_to_squeeze: - arg = arg[tuple(slice(None) if idim not in axes_to_squeeze else 0 - for idim in range(arg.ndim))] - acc_descrs = tuple(acc_descr + new_arg = _verify_is_array(self.rec(arg, axes_to_squeeze)) + new_acc_descrs = tuple(acc_descr for idim, acc_descr in enumerate(acc_descrs) if idim not in axes_to_squeeze) + else: + new_arg = _verify_is_array(self.rec(arg)) + new_acc_descrs = acc_descrs - new_args.append(arg) - new_access_descriptors.append(acc_descrs) + new_args.append(new_arg) + new_access_descriptors.append(new_acc_descrs) assert len(new_args) == len(expr.args) assert len(new_access_descriptors) == len(expr.access_descriptors) - return Einsum(tuple(new_access_descriptors), - tuple(new_args), - expr.redn_axis_to_redn_descr, - tags=expr.tags, - axes=expr.axes,) + if ( + all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)) + and all( + new_acc_descr is acc_descr + for acc_descr, new_acc_descr in zip( + expr.access_descriptors, + new_access_descriptors, + strict=True))): + return expr + else: + return Einsum(tuple(new_access_descriptors), + tuple(new_args), + axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: @@ -97,6 +169,6 @@ def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: alter its value. 
""" mapper = EinsumWithNoBroadcastsRewriter() - return cast("MappedT", mapper(expr)) + return cast("MappedT", mapper(expr, None)) # vim:fdm=marker diff --git a/pytato/utils.py b/pytato/utils.py index 31247897d..0af2b79da 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -31,8 +31,24 @@ ) import islpy as isl -import numpy as np +from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, + Optional, Iterable, TypeVar, FrozenSet) +from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, + DtypeOrScalar, ArrayOrScalar, BasicIndex, + AdvancedIndexInContiguousAxes, + AdvancedIndexInNoncontiguousAxes, + ConvertibleToIndexExpr, IndexExpr, NormalizedSlice, + _dtype_any, Einsum) +#from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, +# SCALAR_CLASSES, INT_CLASSES, BoolT) +from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, + SCALAR_CLASSES, INT_CLASSES) +from pytools import UniqueNameGenerator +from pytato.transform import Mapper +from pytools.tag import Tag from immutabledict import immutabledict +import numpy as np +import islpy as isl import pymbolic.primitives as prim from pymbolic import ArithmeticExpression, Bool, Scalar @@ -172,6 +188,19 @@ def with_indices_for_broadcasted_shape(val: prim.Variable, shape: ShapeType, return val[get_indexing_expression(shape, result_shape)] +def extract_dtypes_or_scalars( + exprs: Sequence[ArrayOrScalar]) -> List[DtypeOrScalar]: + dtypes: List[DtypeOrScalar] = [] + for expr in exprs: + if isinstance(expr, Array): + dtypes.append(expr.dtype) + else: + assert isinstance(expr, SCALAR_CLASSES) + dtypes.append(expr) + + return dtypes + + def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, bnd_name: str, bindings: dict[str, Array], @@ -210,16 +239,19 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, ) -> ArrayOrScalar: from pytato.array import _get_default_axes + if isinstance(a1, SCALAR_CLASSES): + a1 = np.dtype(type(a1)).type(a1) + + if isinstance(a2, SCALAR_CLASSES): + a2 = np.dtype(type(a2)).type(a2) + if np.isscalar(a1) and np.isscalar(a2): from pytato.scalar_expr import evaluate return evaluate(op(a1, a2)) # type: ignore result_shape = get_shape_after_broadcasting([a1, a2]) - - # Note: get_result_type calls np.result_type by default, which means - # that we are passing a pytato array to numpy. Luckily, np.result_type - # only looks at the dtype of input arrays as of numpy v2.1. - result_dtype = get_result_type(a1, a2) + dtypes = extract_dtypes_or_scalars([a1, a2]) + result_dtype = get_result_type(*dtypes) bindings: dict[str, Array] = {} @@ -340,8 +372,10 @@ def are_shape_components_equal( if isinstance(dim1, INT_CLASSES) and isinstance(dim2, INT_CLASSES): return dim1 == dim2 + from pytato.transform import Deduplicator dim1_minus_dim2 = dim1 - dim2 assert isinstance(dim1_minus_dim2, Array) + dim1_minus_dim2 = Deduplicator()(dim1_minus_dim2) from pytato.transform import InputGatherer inputs = InputGatherer()(dim1_minus_dim2) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index c0c3e7945..7420d1708 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -178,9 +178,10 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" +# FIXME: Make this inherit from CachedWalkMapper instead? 
class ArrayToDotNodeInfoMapper(CachedMapper[None, None, []]): def __init__(self) -> None: - super().__init__() + super().__init__(err_on_collision=False) self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {} self.functions: set[FunctionDefinition] = set() diff --git a/test/test_apps.py b/test/test_apps.py index f39be848c..bdb3afc14 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -39,7 +39,7 @@ from pytools.tag import Tag, tag_dataclass import pytato as pt -from pytato.transform import CopyMapper, WalkMapper +from pytato.transform import CopyMapper, Deduplicator, WalkMapper # {{{ Trace an FFT @@ -78,40 +78,21 @@ def map_constant(self, expr): class FFTRealizationMapper(CopyMapper): - def __init__(self, fft_vec_gatherer): - super().__init__() - - self.fft_vec_gatherer = fft_vec_gatherer - - self.old_array_to_new_array = {} - levels = sorted(fft_vec_gatherer.level_to_arrays, reverse=True) - - lev = 0 - arrays = fft_vec_gatherer.level_to_arrays[lev] - self.finalized = False - - for lev in levels: - arrays = fft_vec_gatherer.level_to_arrays[lev] - rec_arrays = [self.rec(ary) for ary in arrays] - # reset cache so that the partial subs are not stored - self._cache.clear() - lev_array = pt.concatenate(rec_arrays, axis=0) - assert lev_array.shape == (fft_vec_gatherer.n,) - - startidx = 0 - for array in arrays: - size = array.shape[0] - sub_array = lev_array[startidx:startidx+size] - startidx += size - self.old_array_to_new_array[array] = sub_array - - assert startidx == fft_vec_gatherer.n - self.finalized = True + def __init__(self, old_array_to_new_array): + # Must use err_on_created_duplicate=False, because the use of ConstantSizer + # in map_index_lambda creates IndexLambdas that differ only in the type of + # their contained constants, which changes their identity but not their + # equality + super().__init__(err_on_created_duplicate=False) + self.old_array_to_new_array = old_array_to_new_array def map_index_lambda(self, expr): tags = expr.tags_of_type(FFTIntermediate) - if tags and (self.finalized or expr in self.old_array_to_new_array): - return self.old_array_to_new_array[expr] + if tags: + try: + return self.old_array_to_new_array[expr] + except KeyError: + pass return super().map_index_lambda( expr.copy(expr=ConstantSizer()(expr.expr))) @@ -122,6 +103,29 @@ def map_concatenate(self, expr): (ImplStored(), PrefixNamed("concat"))) +def make_fft_realization_mapper(fft_vec_gatherer): + old_array_to_new_array = {} + levels = sorted(fft_vec_gatherer.level_to_arrays, reverse=True) + + for lev in levels: + lev_mapper = FFTRealizationMapper(old_array_to_new_array) + arrays = fft_vec_gatherer.level_to_arrays[lev] + rec_arrays = [lev_mapper(ary) for ary in arrays] + lev_array = pt.concatenate(rec_arrays, axis=0) + assert lev_array.shape == (fft_vec_gatherer.n,) + + startidx = 0 + for array in arrays: + size = array.shape[0] + sub_array = lev_array[startidx:startidx+size] + startidx += size + old_array_to_new_array[array] = sub_array + + assert startidx == fft_vec_gatherer.n + + return FFTRealizationMapper(old_array_to_new_array) + + def test_trace_fft(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -134,10 +138,11 @@ def test_trace_fft(ctx_factory): wrap_intermediate_with_level=( lambda level, ary: ary.tagged(FFTIntermediate(level)))) + result = Deduplicator()(result) fft_vec_gatherer = FFTVectorGatherer(n) fft_vec_gatherer(result) - mapper = FFTRealizationMapper(fft_vec_gatherer) + mapper = make_fft_realization_mapper(fft_vec_gatherer) result = mapper(result) diff --git 
a/test/test_codegen.py b/test/test_codegen.py index 0c6972cf6..83be983d8 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -317,6 +317,7 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse): out_ref = np_op(x_in, y_orig.astype(dtype)) assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) + # In some cases ops are done in float32 in loopy but float64 in numpy. assert np.allclose(out, out_ref), (out, out_ref) @@ -926,7 +927,7 @@ def _get_x_shape(_m, n_): x = pt.make_data_wrapper(x_in, shape=_get_x_shape(m, n)) np_out = np.einsum("ij, j -> i", A_in, x_in) - pt_expr = pt.einsum("ij, j -> i", A, x) + pt_expr = pt.transform.Deduplicator()(pt.einsum("ij, j -> i", A, x)) _, (pt_out,) = pt.generate_loopy(pt_expr)(cq, m=m_in, n=n_in) @@ -1582,8 +1583,9 @@ def get_np_input_args(): np_inputs = get_np_input_args() np_result = kernel(np, **np_inputs) - pt_dag = kernel(pt, **{kw: pt.make_data_wrapper(arg) - for kw, arg in np_inputs.items()}) + pt_dag = pt.transform.Deduplicator()( + kernel(pt, **{kw: pt.make_data_wrapper(arg) + for kw, arg in np_inputs.items()})) knl = pt.generate_loopy(pt_dag, options=lp.Options(write_code=True)) @@ -1621,7 +1623,8 @@ def test_zero_size_cl_array_dedup(ctx_factory): dedup_dw_out, count_duplicates=True) # 'x2' would be merged with 'x1' as both of them point to the same data # 'x3' would be merged with 'x4' as both of them point to the same data - assert num_nodes_new == (num_nodes_old - 2) + # '2*x2' would be merged with '2*x1' as they are identical expressions + assert num_nodes_new == (num_nodes_old - 3) # {{{ test_deterministic_codegen @@ -1938,10 +1941,12 @@ def build_expression(tracer): "baz": 65 * twice_x, "quux": 7 * twice_x_2} - result_with_functions = pt.tag_all_calls_to_be_inlined( - pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) - result_without_functions = pt.make_dict_of_named_arrays( - build_expression(lambda fn, *args: fn(*args))) + expr = pt.transform.Deduplicator()( + pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) + + result_with_functions = pt.tag_all_calls_to_be_inlined(expr) + result_without_functions = pt.transform.Deduplicator()( + pt.make_dict_of_named_arrays(build_expression(lambda fn, *args: fn(*args)))) # test that visualizing graphs with functions works dot = pt.get_dot_graph(result_with_functions) @@ -1958,7 +1963,8 @@ def build_expression(tracer): np.testing.assert_allclose(outputs[key], expected[key]) -def test_nested_function_calls(ctx_factory): +@pytest.mark.parametrize("should_concatenate_bar", (False, True)) +def test_nested_function_calls(ctx_factory, should_concatenate_bar): from functools import partial ctx = ctx_factory() @@ -1991,7 +1997,16 @@ def call_bar(tracer, x, y): result = pt.make_dict_of_named_arrays({"out1": call_bar(pt.trace_call, x1, y1), "out2": call_bar(pt.trace_call, x2, y2)} ) + result = pt.transform.Deduplicator()(result) result = pt.tag_all_calls_to_be_inlined(result) + if should_concatenate_bar: + from pytato.transform.calls import CallSiteDependencyCollector + assert len(CallSiteDependencyCollector(())(result)) == 4 + result = pt.concatenate_calls( + result, + lambda x: pt.tags.FunctionIdentifier("bar") in x.call.function.tags) + assert len(CallSiteDependencyCollector(())(result)) == 2 + expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1), "out2": call_bar(ref_tracer, x2, y2)} ) @@ -2063,6 +2078,111 @@ def test_pow_arg_casting(ctx_factory): (float, np.float32, np.float64) +def test_concatenate_calls_no_nested(ctx_factory): + 
+    rng = np.random.default_rng(0)
+
+    ctx = ctx_factory()
+    cq = cl.CommandQueue(ctx)
+
+    def foo(x, y):
+        return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y)
+
+    x1 = pt.make_placeholder("x1", (10, 4), np.float64)
+    x2 = pt.make_placeholder("x2", (10, 4), np.float64)
+
+    y1 = pt.make_placeholder("y1", (10, 4), np.float64)
+    y2 = pt.make_placeholder("y2", (10, 4), np.float64)
+
+    z1 = pt.make_placeholder("z1", (10, 4), np.float64)
+    z2 = pt.make_placeholder("z2", (10, 4), np.float64)
+
+    result = pt.make_dict_of_named_arrays({"out1": 2*pt.trace_call(foo, 2*x1, 3*x2),
+                                           "out2": 4*pt.trace_call(foo, 4*y1, 9*y2),
+                                           "out3": 6*pt.trace_call(foo, 7*z1, 8*z2)
+                                           })
+    result = pt.transform.Deduplicator()(result)
+
+    concatenated_result = pt.concatenate_calls(
+        result, lambda x: pt.tags.FunctionIdentifier("foo") in x.call.function.tags)
+
+    result = pt.tag_all_calls_to_be_inlined(result)
+    concatenated_result = pt.tag_all_calls_to_be_inlined(concatenated_result)
+
+    assert (pt.analysis.get_num_nodes(pt.inline_calls(result))
+            > pt.analysis.get_num_nodes(pt.inline_calls(concatenated_result)))
+
+    x1_np, x2_np, y1_np, y2_np, z1_np, z2_np = rng.random((6, 10, 4))
+
+    _, out_dict1 = pt.generate_loopy(result)(cq,
+                                             x1=x1_np, x2=x2_np,
+                                             y1=y1_np, y2=y2_np,
+                                             z1=z1_np, z2=z2_np)
+
+    _, out_dict2 = pt.generate_loopy(concatenated_result)(cq,
+                                                          x1=x1_np, x2=x2_np,
+                                                          y1=y1_np, y2=y2_np,
+                                                          z1=z1_np, z2=z2_np)
+    assert out_dict1.keys() == out_dict2.keys()
+
+    for key in out_dict1:
+        np.testing.assert_allclose(out_dict1[key], out_dict2[key])
+
+
+def test_concatenation_via_constant_expressions(ctx_factory):
+
+    from pytato.transform.calls import CallSiteDependencyCollector
+
+    rng = np.random.default_rng(0)
+
+    ctx = ctx_factory()
+    cq = cl.CommandQueue(ctx)
+
+    def resampling(coords, iels):
+        return coords[iels]
+
+    n_el = 1000
+    n_dof = 20
+    n_dim = 3
+
+    n_left_els = 17
+    n_right_els = 29
+
+    coords_dofs_np = rng.random((n_el, n_dim, n_dof), np.float64)
+    left_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_left_els)
+    right_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_right_els)
+
+    coords_dofs = pt.make_data_wrapper(coords_dofs_np)
+    left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np)
+    right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np)
+
+    lcoords = pt.trace_call(resampling, coords_dofs, left_bnd_iels)
+    rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels)
+
+    result = pt.make_dict_of_named_arrays({"lcoords": lcoords,
+                                           "rcoords": rcoords})
+    result = pt.transform.Deduplicator()(result)
+    result = pt.tag_all_calls_to_be_inlined(result)
+
+    assert len(CallSiteDependencyCollector(())(result)) == 2
+    concated_result = pt.concatenate_calls(
+        result,
+        lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags
+    )
+    assert len(CallSiteDependencyCollector(())(concated_result)) == 1
+
+    _, out_result = pt.generate_loopy(result)(cq)
+    np.testing.assert_allclose(out_result["lcoords"],
+                               coords_dofs_np[left_bnd_iels_np])
+    np.testing.assert_allclose(out_result["rcoords"],
+                               coords_dofs_np[right_bnd_iels_np])
+
+    _, out_concated_result = pt.generate_loopy(concated_result)(cq)
+    np.testing.assert_allclose(out_concated_result["lcoords"],
+                               coords_dofs_np[left_bnd_iels_np])
+    np.testing.assert_allclose(out_concated_result["rcoords"],
+                               coords_dofs_np[right_bnd_iels_np])
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
diff --git a/test/test_distributed.py b/test/test_distributed.py
index d78479e08..65214c4b0 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py @@ -555,12 +555,13 @@ def _test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory): x_np = rng.random((10, 4)) x = pt.make_data_wrapper(cla.to_device(queue, x_np)) y = 2 * x + ones = pt.ones(10) send1 = pt.staple_distributed_send( y, dest_rank=1, comm_tag=42, - stapled_to=pt.ones(10)) + stapled_to=ones) send2 = pt.staple_distributed_send( y, dest_rank=2, comm_tag=42, - stapled_to=pt.ones(10)) + stapled_to=ones) z = 4 * y dag = pt.make_dict_of_named_arrays({"z": z, "send1": send1, "send2": send2}) else: diff --git a/test/test_pytato.py b/test/test_pytato.py index da176f124..e6457f338 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -29,6 +29,7 @@ import dataclasses import sys +from contextlib import contextmanager import numpy as np import pytest @@ -514,8 +515,8 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): dtype='int64', expr=Product((Subscript(Variable('_in0'), (Variable('_0'), Variable('_1'))), - TypeCast(dtype('int64'), Subscript(Variable('_in1'), - (Variable('_0'), Variable('_1')))))), + Subscript(Variable('_in1'), + (Variable('_0'), Variable('_1'))))), bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'), '_in1': IndexLambda( shape=(10, 4), @@ -723,7 +724,7 @@ def test_small_dag_with_duplicates_count(): # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + pt.transform.DependencyMapper(err_on_collision=False)(dag)) assert node_count - num_duplicates == get_num_nodes( dag, count_duplicates=False) @@ -760,7 +761,7 @@ def test_large_dag_with_duplicates_count(): # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + pt.transform.DependencyMapper(err_on_collision=False)(dag)) assert node_count - num_duplicates == get_num_nodes( dag, count_duplicates=False) @@ -805,6 +806,8 @@ def post_visit(self, expr): assert expr.name == "x" expr, inp = construct_intestine_graph() + expr = pt.transform.Deduplicator()(expr) + result = pt.transform.rec_get_user_nodes(expr, inp) SubexprRecorder()(expr) @@ -932,112 +935,118 @@ def test_einsum_dot_axes_has_correct_dim(): assert len(einsum.axes) == einsum.ndim -def test_created_at(): - pt.set_traceback_tag_enabled() +@contextmanager +def enable_traceback_tag(): + try: + pt.set_traceback_tag_enabled(True) + yield + finally: + pt.set_traceback_tag_enabled(False) - a = pt.make_placeholder("a", (10, 10), "float64") - b = pt.make_placeholder("b", (10, 10), "float64") - # res1 and res2 are defined on different lines and should have different - # CreatedAt tags. - res1 = a+b - res2 = a+b +def test_created_at(): + with enable_traceback_tag(): + a = pt.make_placeholder("a", (10, 10), "float64") + b = pt.make_placeholder("b", (10, 10), "float64") + + # res1 and res2 are defined on different lines and should have different + # CreatedAt tags. + res1 = a+b + res2 = a+b - # res3 and res4 are defined on the same line and should have the same - # CreatedAt tags. - res3 = a+b; res4 = a+b # noqa: E702 + # res3 and res4 are defined on the same line and should have the same + # CreatedAt tags. 
+ res3 = a+b; res4 = a+b # noqa: E702 - # {{{ Check that CreatedAt tags are handled correctly for equality/hashing + # {{{ Check that CreatedAt tags are handled correctly for equality/hashing - assert res1 == res2 == res3 == res4 - assert hash(res1) == hash(res2) == hash(res3) == hash(res4) + assert res1 == res2 == res3 == res4 + assert hash(res1) == hash(res2) == hash(res3) == hash(res4) - assert res1.non_equality_tags != res2.non_equality_tags - assert res3.non_equality_tags == res4.non_equality_tags - assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) - assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) + assert res1.non_equality_tags != res2.non_equality_tags + assert res3.non_equality_tags == res4.non_equality_tags + assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) + assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) - assert res1.tags == res2.tags == res3.tags == res4.tags - assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) + assert res1.tags == res2.tags == res3.tags == res4.tags + assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) - # }}} + # }}} - from pytato.tags import CreatedAt + from pytato.tags import CreatedAt - created_tag = frozenset({tag - for tag in res1.non_equality_tags - if isinstance(tag, CreatedAt)}) + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) - assert len(created_tag) == 1 + assert len(created_tag) == 1 - # {{{ Make sure the function name appears in the traceback + # {{{ Make sure the function name appears in the traceback - tag, = created_tag + tag, = created_tag - found = False + found = False - stacksummary = tag.traceback.to_stacksummary() - assert len(stacksummary) > 10 + stacksummary = tag.traceback.to_stacksummary() + assert len(stacksummary) > 10 - for frame in tag.traceback.frames: - if frame.name == "test_created_at" and "a+b" in frame.line: - found = True - break + for frame in tag.traceback.frames: + if frame.name == "test_created_at" and "a+b" in frame.line: + found = True + break - assert found + assert found - # }}} + # }}} - # {{{ Make sure that CreatedAt tags are in the visualization + # {{{ Make sure that CreatedAt tags are in the visualization - from pytato.visualization import get_dot_graph - s = get_dot_graph(res1) - assert "test_created_at" in s - assert "a+b" in s + from pytato.visualization import get_dot_graph + s = get_dot_graph(res1) + assert "test_created_at" in s + assert "a+b" in s - # }}} + # }}} - # {{{ Make sure only a single CreatedAt tag is created + # {{{ Make sure only a single CreatedAt tag is created - old_tag = tag + old_tag = tag - res1 = res1 + res2 + res1 = res1 + res2 - created_tag = frozenset({tag - for tag in res1.non_equality_tags - if isinstance(tag, CreatedAt)}) + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) - assert len(created_tag) == 1 + assert len(created_tag) == 1 - tag, = created_tag + tag, = created_tag - # Tag should be recreated - assert tag != old_tag + # Tag should be recreated + assert tag != old_tag - # }}} + # }}} - # {{{ Make sure that copying preserves the tag + # {{{ Make sure that copying preserves the tag - old_tag = tag + old_tag = tag - res1_new = pt.transform.map_and_copy(res1, lambda x: x) + res1_new = pt.transform.Deduplicator()(res1) - created_tag = frozenset({tag - for tag in res1_new.non_equality_tags - if isinstance(tag, CreatedAt)}) + created_tag = 
frozenset({tag + for tag in res1_new.non_equality_tags + if isinstance(tag, CreatedAt)}) - assert len(created_tag) == 1 + assert len(created_tag) == 1 - tag, = created_tag + tag, = created_tag - assert old_tag == tag + assert old_tag == tag - # }}} + # }}} # {{{ Test disabling traceback creation - pt.set_traceback_tag_enabled(False) - a = pt.make_placeholder("a", (10, 10), "float64") created_tag = frozenset({tag @@ -1160,7 +1169,7 @@ class ExistentTag(Tag): out = make_random_dag(rdagc_pt).tagged(ExistentTag()) - dag = pt.make_dict_of_named_arrays({"out": out}) + dag = pt.transform.Deduplicator()(pt.make_dict_of_named_arrays({"out": out})) # get_num_nodes() returns an extra DictOfNamedArrays node assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) diff --git a/test/testlib.py b/test/testlib.py index a28dec67e..7d58df480 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -101,6 +101,7 @@ def __init__( rng: np.random.Generator, axis_len: int, use_numpy: bool, + allow_duplicate_nodes: bool = False, additional_generators: ( Sequence[tuple[int, Callable[[RandomDAGContext], Array]]] | None) = None @@ -115,6 +116,7 @@ def __init__( self.axis_len = axis_len self.past_results: list[Array] = [] self.use_numpy = use_numpy + self.allow_duplicate_nodes = allow_duplicate_nodes if additional_generators is None: additional_generators = [] @@ -156,6 +158,14 @@ def make_random_reshape( def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: + if not rdagc.use_numpy and not rdagc.allow_duplicate_nodes: + def dedup(expr: Array) -> Array: + return pt.transform._verify_is_array(pt.transform.Deduplicator()(expr)) + + else: + def dedup(expr: Array) -> Array: + return expr + rng = rdagc.rng max_prob_hardcoded = 1500 @@ -166,7 +176,7 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: v = rng.integers(0, max_prob_hardcoded + additional_prob) if v < 600: - return make_random_constant(rdagc, naxes=rng.integers(1, 3)) + return dedup(make_random_constant(rdagc, naxes=rng.integers(1, 3))) elif v < 1000: op1 = make_random_dag(rdagc) @@ -189,9 +199,9 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: # just inserted a few new 1-long axes. Those need to go before we # return. 
if which_op in ["maximum", "minimum"]: - return rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2)) + return dedup(rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2))) else: - return rdagc.np.squeeze(which_op(op1, op2)) + return dedup(rdagc.np.squeeze(which_op(op1, op2))) elif v < 1075: op1 = make_random_dag(rdagc) @@ -199,24 +209,26 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: if op1.ndim <= 1 and op2.ndim <= 1: continue - return op1 @ op2 + return dedup(op1 @ op2) elif v < 1275: if not rdagc.past_results: continue - return rdagc.past_results[rng.integers(0, len(rdagc.past_results))] + return dedup( + rdagc.past_results[rng.integers(0, len(rdagc.past_results))]) elif v < max_prob_hardcoded: result = make_random_dag(rdagc) - return rdagc.np.transpose( + return dedup( + rdagc.np.transpose( result, - tuple(rng.permuted(list(range(result.ndim))))) + tuple(rng.permuted(list(range(result.ndim)))))) else: base_prob = max_prob_hardcoded for fake_prob, gen_func in rdagc.additional_generators: if base_prob <= v < base_prob + fake_prob: - return gen_func(rdagc) + return dedup(gen_func(rdagc)) base_prob += fake_prob @@ -237,6 +249,14 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: of the array are of length :attr:`RandomDAGContext.axis_len` (there is at least one axis, but arbitrarily more may be present). """ + if not rdagc.use_numpy and not rdagc.allow_duplicate_nodes: + def dedup(expr: Array) -> Array: + return pt.transform._verify_is_array(pt.transform.Deduplicator()(expr)) + + else: + def dedup(expr: Array) -> Array: + return expr + rng = rdagc.rng result = make_random_dag_inner(rdagc) @@ -248,14 +268,15 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: subscript[rng.integers(0, result.ndim)] = int( rng.integers(0, rdagc.axis_len)) - return result[tuple(subscript)] + return dedup(result[tuple(subscript)]) elif v == 1: # reduce away an axis # FIXME do reductions other than sum? - return rdagc.np.sum( - result, axis=int(rng.integers(0, result.ndim))) + return dedup( + rdagc.np.sum( + result, axis=int(rng.integers(0, result.ndim)))) else: raise AssertionError() @@ -275,7 +296,8 @@ def get_random_pt_dag(seed: int, Sequence[tuple[int, Callable[[RandomDAGContext], Array]]] | None) = None, axis_len: int = 4, - convert_dws_to_placeholders: bool = False + convert_dws_to_placeholders: bool = False, + allow_duplicate_nodes: bool = False ) -> pt.DictOfNamedArrays: if additional_generators is None: additional_generators = [] @@ -286,6 +308,7 @@ def get_random_pt_dag(seed: int, rdagc_comm = RandomDAGContext(np.random.default_rng(seed=seed), axis_len=axis_len, use_numpy=False, + allow_duplicate_nodes=allow_duplicate_nodes, additional_generators=additional_generators) dag = pt.make_dict_of_named_arrays({"result": make_random_dag(rdagc_comm)})