diff --git a/pytato/array.py b/pytato/array.py
index 2782b24c9..598fa8c7b 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from traceback import FrameSummary, StackSummary
 
 __copyright__ = """
 Copyright (C) 2020 Andreas Kloeckner
@@ -593,25 +594,37 @@ def _unary_op(self, op: Any) -> Array:
                 non_equality_tags=_get_created_at_tag(),
                 var_to_reduction_descr=immutabledict())
 
-    __mul__ = partialmethod(_binary_op, operator.mul)
-    __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True)
+    # NOTE: Initializing the expression to "prim.Product(expr1, expr2)" is
+    # essential as opposed to performing "expr1 * expr2". This is to account
+    # for pymbolic's implementation of the "*" operator which might not
+    # instantiate the node corresponding to the operation when one of
+    # the operands is the neutral element of the operation.
+    #
+    # For the same reason 'prim.(Sum|FloorDiv|Quotient)' is preferred over the
+    # python operators on the operands.
 
-    __add__ = partialmethod(_binary_op, operator.add)
-    __radd__ = partialmethod(_binary_op, operator.add, reverse=True)
+    __mul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)))
+    __rmul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)),
+                             reverse=True)
 
-    __sub__ = partialmethod(_binary_op, operator.sub)
-    __rsub__ = partialmethod(_binary_op, operator.sub, reverse=True)
+    __add__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)))
+    __radd__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)),
+                             reverse=True)
 
-    __floordiv__ = partialmethod(_binary_op, operator.floordiv)
-    __rfloordiv__ = partialmethod(_binary_op, operator.floordiv, reverse=True)
+    __sub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)))
+    __rsub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)),
+                             reverse=True)
 
-    __truediv__ = partialmethod(_binary_op, operator.truediv,
+    __floordiv__ = partialmethod(_binary_op, prim.FloorDiv)
+    __rfloordiv__ = partialmethod(_binary_op, prim.FloorDiv, reverse=True)
+
+    __truediv__ = partialmethod(_binary_op, prim.Quotient,
                                 get_result_type=_truediv_result_type)
-    __rtruediv__ = partialmethod(_binary_op, operator.truediv,
+    __rtruediv__ = partialmethod(_binary_op, prim.Quotient,
                                  get_result_type=_truediv_result_type,
                                  reverse=True)
 
-    __pow__ = partialmethod(_binary_op, operator.pow)
-    __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True)
+    __pow__ = partialmethod(_binary_op, prim.Power)
+    __rpow__ = partialmethod(_binary_op, prim.Power, reverse=True)
 
     __neg__ = partialmethod(_unary_op, operator.neg)
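For context, a minimal sketch (not part of the patch) of the pymbolic behavior the NOTE above refers to; the variable name is arbitrary and the printed results are what current pymbolic is expected to produce:

import pymbolic.primitives as prim

x = prim.Variable("x")

# The overloaded operator may fold the neutral element away, so no Product
# node appears in the expression tree.
print(repr(x * 1))                  # Variable('x')

# Constructing the node explicitly always records the operation.
print(repr(prim.Product((x, 1))))   # Product((Variable('x'), 1))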
@@ -1488,8 +1501,7 @@ class Reshape(IndexRemappingBase):
 
     if __debug__:
         def __attrs_post_init__(self) -> None:
-            # FIXME: Get rid of this restriction
-            assert self.order == "C"
+            # assert self.non_equality_tags
             super().__attrs_post_init__()
 
     @property
@@ -1981,8 +1993,7 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]],
     """
     :param array: array to be reshaped
     :param newshape: shape of the resulting array
-    :param order: ``"C"`` or ``"F"``. Layout order of the result array. Only
-        ``"C"`` allowed for now.
+    :param order: ``"C"`` or ``"F"``. Layout order of the resulting array.
 
     .. note::
 
@@ -2002,9 +2013,6 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]],
     if not all(isinstance(axis_len, INT_CLASSES) for axis_len in array.shape):
         raise ValueError("reshape of arrays with symbolic lengths not allowed")
 
-    if order != "C":
-        raise NotImplementedError("Reshapes to a 'F'-ordered arrays")
-
     newshape_explicit = []
 
     for new_axislen in newshape:
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index 6acd5389f..7752e7ec8 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -1804,6 +1804,33 @@ def tag_user_nodes(
 # }}}
 
 
+# {{{ BranchMorpher
+
+class BranchMorpher(CopyMapper):
+    """
+    A mapper that replaces equal segments of graphs with identical objects.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {}
+
+    def cache_key(self, expr: CachedMapperT) -> Any:
+        return (id(expr), expr)
+
+    # type-ignore reason: incompatible with Mapper.rec
+    def rec(self, expr: MappedT) -> MappedT:  # type: ignore[override]
+        rec_expr = super().rec(expr)
+        try:
+            # type-ignored because 'result_cache' maps to ArrayOrNames
+            return self.result_cache[rec_expr]  # type: ignore[return-value]
+        except KeyError:
+            self.result_cache[rec_expr] = rec_expr
+            # type-ignored because of super-class' relaxed types
+            return rec_expr  # type: ignore[no-any-return]
+
+# }}}
+
+
 # {{{ deduplicate_data_wrappers
 
 def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable:
@@ -1893,8 +1920,9 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
                 len(data_wrapper_cache),
                 data_wrappers_encountered - len(data_wrapper_cache))
 
-    return array_or_names
+    return BranchMorpher()(array_or_names)
 
 # }}}
 
+
 # vim: foldmethod=marker
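A sketch (not part of the patch) of what the new BranchMorpher pass is meant to buy when it runs at the end of deduplicate_data_wrappers: structurally equal sub-expressions that were built independently should come back as the identical object. The array contents and output names below are made up.

import numpy as np
import pytato as pt
from pytato.transform import deduplicate_data_wrappers

x = pt.make_data_wrapper(np.ones(10))

# "2*x" is constructed twice, yielding two equal but distinct Array objects.
dag = pt.make_dict_of_named_arrays({"out1": 2*x, "out2": 2*x + 3})

# After the pass, equal branches of the graph are expected to be represented
# by the same object (id-equal), not merely __eq__-equal.
deduped = deduplicate_data_wrappers(dag)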
diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py
index f31200dd7..892fa2c20 100644
--- a/pytato/transform/lower_to_index_lambda.py
+++ b/pytato/transform/lower_to_index_lambda.py
@@ -37,7 +37,7 @@
                           AdvancedIndexInNoncontiguousAxes, NormalizedSlice,
                           ShapeType, AbstractResultWithNamedArrays)
-from pytato.scalar_expr import ScalarExpression, INT_CLASSES, IntegralT
+from pytato.scalar_expr import ScalarExpression, INT_CLASSES
 from pytato.diagnostic import CannotBeLoweredToIndexLambda
 from pytato.tags import AssumeNonNegative
 from pytato.transform import Mapper
@@ -51,30 +51,58 @@ def _get_reshaped_indices(expr: Reshape) -> Tuple[ScalarExpression, ...]:
         assert expr.size == 1
         return ()
 
-    if expr.order != "C":
-        raise NotImplementedError(expr.order)
+    if expr.order.upper() not in ["C", "F"]:
+        raise NotImplementedError("Order expected to be 'C' or 'F' "
+                                  f"(case insensitive), found {expr.order}")
 
-    newstrides: List[IntegralT] = [1]  # reshaped array strides
-    for new_axis_len in reversed(expr.shape[1:]):
-        assert isinstance(new_axis_len, INT_CLASSES)
-        newstrides.insert(0, newstrides[0]*new_axis_len)
+    order = expr.order.upper()
+    oldshape = expr.array.shape
+    newshape = expr.shape
 
-    flattened_idx = sum(prim.Variable(f"_{i}")*stride
-                        for i, stride in enumerate(newstrides))
+    # {{{ compute strides
+
+    oldstrides = [1]
+    oldstride_axes = (reversed(oldshape[1:]) if order == "C" else oldshape[:-1])
+
+    for ax_len in oldstride_axes:
+        assert isinstance(ax_len, INT_CLASSES)
+        oldstrides.append(oldstrides[-1]*ax_len)
+
+    newstrides = [1]
+    newstride_axes = (reversed(newshape[1:]) if order == "C" else newshape[:-1])
+
+    for ax_len in newstride_axes:
+        assert isinstance(ax_len, INT_CLASSES)
+        newstrides.append(newstrides[-1]*ax_len)
+
+    # }}}
+
+    # {{{ compute size tills
 
-    oldstrides: List[IntegralT] = [1]  # input array strides
-    for axis_len in reversed(expr.array.shape[1:]):
-        assert isinstance(axis_len, INT_CLASSES)
-        oldstrides.insert(0, oldstrides[0]*axis_len)
+    oldsizetills = [oldshape[-1] if order == "C" else oldshape[0]]
+    oldsizetill_ax = (oldshape[:-1][::-1] if order == "C" else oldshape[1:])
+    for ax_len in oldsizetill_ax:
+        oldsizetills.append(oldsizetills[-1]*ax_len)
+
+    # }}}
+
+    # {{{ if order is C, then computed info is backwards
+
+    if order == "C":
+        oldstrides = oldstrides[::-1]
+        newstrides = newstrides[::-1]
+        oldsizetills = oldsizetills[::-1]
+
+    # }}}
+
+    flattened_idx = sum(prim.Variable(f"_{i}")*stride
+                        for i, stride in enumerate(newstrides))
 
-    assert isinstance(expr.array.shape[-1], INT_CLASSES)
-    oldsizetills = [expr.array.shape[-1]]  # input array size till for axes idx
-    for old_axis_len in reversed(expr.array.shape[:-1]):
-        assert isinstance(old_axis_len, INT_CLASSES)
-        oldsizetills.insert(0, oldsizetills[0]*old_axis_len)
+    ret = tuple(
+        (flattened_idx % sizetill) // stride
+        for stride, sizetill in zip(oldstrides, oldsizetills))
 
-    return tuple(((flattened_idx % sizetill) // stride)
-                 for stride, sizetill in zip(oldstrides, oldsizetills))
+    return ret
 
 
 class ToIndexLambdaMixin:
diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py
index 7f015d739..638c4d08b 100644
--- a/pytato/transform/metadata.py
+++ b/pytato/transform/metadata.py
@@ -35,7 +35,7 @@
 """
 
 
-from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict, FrozenSet,
+from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict,
                     Mapping, Iterable, Any, TypeVar, cast)
 from bidict import bidict
 from pytato.scalar_expr import SCALAR_CLASSES
@@ -58,7 +58,7 @@
 from pytato.diagnostic import UnknownIndexLambdaExpr
 from pytools import UniqueNameGenerator
-from pytools.tag import Tag
+from pytools.tag import Tag, UniqueTag
 
 import logging
 logger = logging.getLogger(__name__)
@@ -70,6 +70,22 @@
 GraphNodeT = TypeVar("GraphNodeT")
 
 
+# {{{ AxisIgnoredForPropagationTag
+
+class AxisIgnoredForPropagationTag(UniqueTag):
+    """
+    Tag to exclude an axis from the equality constraints that determine
+    which axis tags are allowed to propagate along.
+
+    The intended use case is to prevent the axes of, for example, a matrix
+    used to differentiate a tensor of DOF data from picking up the unique
+    tags attached to the axes of that tensor.
+    """
+    pass
+
+# }}}
+
+
 # {{{ AxesTagsEquationCollector
 
 class AxesTagsEquationCollector(Mapper):
@@ -167,6 +183,8 @@ def record_equations_from_axes_tags(self, ary: Array) -> None:
         Records equations for *ary*\'s axis tags of type :attr:`tag_t`.
         """
         for iaxis, axis in enumerate(ary.axes):
+            if axis.tags_of_type(AxisIgnoredForPropagationTag):
+                continue
             lhs_var = self.get_var_for_axis(ary, iaxis)
             for tag in axis.tags_of_type(self.tag_t):
                 rhs_var = self.get_var_for_tag(tag)
@@ -492,11 +510,12 @@ def map_einsum(self, expr: Einsum) -> None:
             descr_to_var[EinsumElementwiseAxis(iaxis)] = self.get_var_for_axis(expr,
                                                                                iaxis)
 
-        for access_descrs, arg in zip(expr.access_descriptors,
-                                      expr.args):
+        for access_descrs, arg in zip(expr.access_descriptors, expr.args):
             for iarg_axis, descr in enumerate(access_descrs):
-                in_tag_var = self.get_var_for_axis(arg, iarg_axis)
+                if arg.axes[iarg_axis].tags_of_type(AxisIgnoredForPropagationTag):
+                    continue
 
+                in_tag_var = self.get_var_for_axis(arg, iarg_axis)
                 if descr in descr_to_var:
                     self.record_equation(descr_to_var[descr], in_tag_var)
                 else:
@@ -556,38 +575,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array:
 
 # }}}
 
 
-def _get_propagation_graph_from_constraints(
-        equations: List[Tuple[str, str]]) -> Mapping[str, FrozenSet[str]]:
-    from immutabledict import immutabledict
-    propagation_graph: Dict[str, Set[str]] = {}
-    for lhs, rhs in equations:
-        assert lhs != rhs
-        propagation_graph.setdefault(lhs, set()).add(rhs)
-        propagation_graph.setdefault(rhs, set()).add(lhs)
-
-    return immutabledict({k: frozenset(v)
-                          for k, v in propagation_graph.items()})
-
-
-def get_reachable_nodes(undirected_graph: Mapping[GraphNodeT, Iterable[GraphNodeT]],
-                        source_node: GraphNodeT) -> FrozenSet[GraphNodeT]:
-    """
-    Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
-    reachable from *source_node*.
-    """
-    nodes_visited: Set[GraphNodeT] = set()
-    nodes_to_visit = {source_node}
-    while nodes_to_visit:
-        current_node = nodes_to_visit.pop()
-        nodes_visited.add(current_node)
-
-        neighbors = undirected_graph[current_node]
-        nodes_to_visit.update({node
-                               for node in neighbors
-                               if node not in nodes_visited})
-
-    return frozenset(nodes_visited)
-
+# {{{ AxisTagAttacher
 
 class AxisTagAttacher(CopyMapper):
     """
@@ -659,6 +647,8 @@ def __call__(self, expr: ArrayOrNames) -> ArrayOrNames:  # type: ignore[override]
         assert isinstance(result, (Array, AbstractResultWithNamedArrays))
         return result
 
+# }}}
+
 
 def unify_axes_tags(
         expr: ArrayOrNames,
@@ -693,19 +683,25 @@ def unify_axes_tags(
     # Defn. A Propagation graph is a graph where nodes denote variables and an
     # edge between 2 nodes denotes an equality criterion.
 
-    propagation_graph = _get_propagation_graph_from_constraints(
-        equations_collector.equations)
+    from pytools.graph import (
+        get_propagation_graph_from_constraints,
+        get_reachable_nodes
+    )
 
     known_tag_vars = frozenset(equations_collector.known_tag_to_var.values())
     axis_to_solved_tags: Dict[Tuple[Array, int], Set[Tag]] = {}
 
+    propagation_graph = get_propagation_graph_from_constraints(
+        equations_collector.equations,
+    )
+
     for tag, var in equations_collector.known_tag_to_var.items():
-        for reachable_var in (get_reachable_nodes(propagation_graph, var) -
-                              known_tag_vars):
-            axis_to_solved_tags.setdefault(
-                equations_collector.axis_to_var.inverse[reachable_var],
-                set()
-            ).add(tag)
+        reachable_nodes = get_reachable_nodes(propagation_graph, var)
+        for reachable_var in (reachable_nodes - known_tag_vars):
+            axis_to_solved_tags.setdefault(
+                equations_collector.axis_to_var.inverse[reachable_var],
+                set()
+            ).add(tag)
 
     return AxisTagAttacher(axis_to_solved_tags,
                            tag_corresponding_redn_descr=unify_redn_descrs,
diff --git a/test/test_codegen.py b/test/test_codegen.py
index 0f1456d9b..6320810c0 100755
--- a/test/test_codegen.py
+++ b/test/test_codegen.py
@@ -1610,7 +1610,7 @@ def test_zero_size_cl_array_dedup(ctx_factory):
     x4 = pt.make_data_wrapper(x_cl2)
 
     out = pt.make_dict_of_named_arrays({"out1": 2*x1,
-                                        "out2": 2*x2,
+                                        "out2": 3*x2,
                                         "out3": x3 + x4
                                         })
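Finally, a usage sketch (not part of the patch) for the new AxisIgnoredForPropagationTag. The shapes, names, and einsum spec are made up, and the entry points used here (pt.make_data_wrapper, pt.make_placeholder, Array.with_tagged_axis, pt.einsum, pt.unify_axes_tags) are assumed to be the usual pytato ones:

import numpy as np
import pytato as pt
from pytato.transform.metadata import AxisIgnoredForPropagationTag

# Hypothetical sizes: a (ndofs, ndofs) differentiation matrix applied to
# (nelements, ndofs) DOF data.
diff_mat = pt.make_data_wrapper(np.random.rand(10, 10))
u = pt.make_placeholder("u", (50, 10), np.float64)

# Tag both axes of the matrix so that they do not take part in axis-tag
# propagation, i.e. they will not pick up tags from u's axes.
for iaxis in range(diff_mat.ndim):
    diff_mat = diff_mat.with_tagged_axis(iaxis, AxisIgnoredForPropagationTag())

du = pt.einsum("ij,ej->ei", diff_mat, u)
du = pt.unify_axes_tags(du)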