From f5c377eea26e84a66397a976c5dd544c6b4e58a0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 22 Mar 2022 22:53:41 -0500 Subject: [PATCH 001/178] Add tag to store array creation traceback --- pytato/array.py | 6 +++++- pytato/tags.py | 10 ++++++++++ test/test_pytato.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index aa17a2f4b..7b3b1e0f6 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -419,8 +419,12 @@ class Array(Taggable): __array_priority__ = 1 # disallow numpy arithmetic to take precedence def __init__(self, axes: AxesT, tags: FrozenSet[Tag]) -> None: + import traceback + v = "".join(traceback.format_stack()) + from pytato.tags import CreatedAt + c = CreatedAt(v) self.axes = axes - self.tags = tags + self.tags = frozenset({*tags, c}) def copy(self: ArrayT, **kwargs: Any) -> ArrayT: for field in self._fields: diff --git a/pytato/tags.py b/pytato/tags.py index f1794c177..f1ed4984c 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -101,3 +101,13 @@ class AssumeNonNegative(Tag): :class:`~pytato.target.Target` that all entries of the tagged array are non-negative. """ + + +@tag_dataclass +class CreatedAt(UniqueTag): + """ + A tag attached to a :class:`~pytato.Array` to store the traceback + of where it was created. + """ + + traceback: str diff --git a/test/test_pytato.py b/test/test_pytato.py index 5f602a240..940ff2353 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -790,6 +790,21 @@ def test_einsum_dot_axes_has_correct_dim(): assert len(einsum.axes) == einsum.ndim +def test_created_at(): + a = pt.make_placeholder("a", (10, 10), "float64") + b = pt.make_placeholder("b", (10, 10), "float64") + + res = a+b + + from pytato.tags import CreatedAt + assert any(isinstance(tag, CreatedAt) for tag in res.tags) + + # Make sure the function name appears in the traceback + for tag in res.tags: + if isinstance(tag, CreatedAt): + assert "test_created_at" in tag.traceback + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From 4c32cb69b862fcca6f40502a946af1cb13818f8d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 23 Mar 2022 11:01:28 -0500 Subject: [PATCH 002/178] don't make it a unique tag --- pytato/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index f1ed4984c..285062a5c 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -104,7 +104,7 @@ class AssumeNonNegative(Tag): @tag_dataclass -class CreatedAt(UniqueTag): +class CreatedAt(Tag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. 
From 56fbf4c9da48b1463044b0df1b3d365427a329d7 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 23 Mar 2022 16:50:24 -0500 Subject: [PATCH 003/178] adds a common _get_default_tags --- pytato/array.py | 40 +++++++++++++++++++++++++++++++++------- pytato/cmath.py | 9 ++++++--- pytato/utils.py | 8 ++++++-- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index aa17a2f4b..3e66d8796 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -474,6 +474,7 @@ def ndim(self) -> int: def T(self) -> Array: return AxisPermutation(self, tuple(range(self.ndim)[::-1]), + tags=_get_default_tags(), axes=_get_default_axes(self.ndim)) @memoize_method @@ -539,6 +540,7 @@ def _unary_op(self, op: Any) -> Array: shape=self.shape, dtype=self.dtype, bindings=bindings, + tags=_get_default_tags(), axes=_get_default_axes(self.ndim)) __mul__ = partialmethod(_binary_op, operator.mul) @@ -1133,6 +1135,7 @@ def einsum(subscripts: str, *operands: Array) -> Einsum: access_descriptors.append(access_descriptor) return Einsum(tuple(access_descriptors), operands, + tags=_get_default_tags(), axes=_get_default_axes(len({descr for descr in index_to_descr.values() if isinstance(descr, @@ -1672,6 +1675,10 @@ def _get_default_axes(ndim: int) -> AxesT: return tuple(Axis(frozenset()) for _ in range(ndim)) +def _get_default_tags() -> TagsType: + return frozenset() + + def _get_matmul_ndim(ndim1: int, ndim2: int) -> int: if ndim1 == 1 and ndim2 == 1: return 0 @@ -1738,7 +1745,9 @@ def roll(a: Array, shift: int, axis: Optional[int] = None) -> Array: if shift == 0: return a - return Roll(a, shift, axis, axes=_get_default_axes(a.ndim)) + return Roll(a, shift, axis, + tags=_get_default_tags(), + axes=_get_default_axes(a.ndim)) def transpose(a: Array, axes: Optional[Sequence[int]] = None) -> Array: @@ -1758,7 +1767,9 @@ def transpose(a: Array, axes: Optional[Sequence[int]] = None) -> Array: if set(axes) != set(range(a.ndim)): raise ValueError("repeated or out-of-bounds axes detected") - return AxisPermutation(a, tuple(axes), axes=_get_default_axes(a.ndim)) + return AxisPermutation(a, tuple(axes), + tags=_get_default_tags(), + axes=_get_default_axes(a.ndim)) def stack(arrays: Sequence[Array], axis: int = 0) -> Array: @@ -1790,7 +1801,9 @@ def stack(arrays: Sequence[Array], axis: int = 0) -> Array: if not (0 <= axis <= arrays[0].ndim): raise ValueError("invalid axis") - return Stack(tuple(arrays), axis, axes=_get_default_axes(arrays[0].ndim+1)) + return Stack(tuple(arrays), axis, + tags=_get_default_tags(), + axes=_get_default_axes(arrays[0].ndim+1)) def concatenate(arrays: Sequence[Array], axis: int = 0) -> Array: @@ -1823,7 +1836,9 @@ def shape_except_axis(ary: Array) -> ShapeType: if not (0 <= axis <= arrays[0].ndim): raise ValueError("invalid axis") - return Concatenate(tuple(arrays), axis, axes=_get_default_axes(arrays[0].ndim)) + return Concatenate(tuple(arrays), axis, + tags=_get_default_tags(), + axes=_get_default_axes(arrays[0].ndim)) def reshape(array: Array, newshape: Union[int, Sequence[int]], @@ -1885,6 +1900,7 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], f" into {newshape}") return Reshape(array, tuple(newshape_explicit), order, + tags=_get_default_tags(), axes=_get_default_axes(len(newshape_explicit))) @@ -1925,7 +1941,8 @@ def make_placeholder(name: str, raise ValueError("'axes' dimensionality mismatch:" f" expected {len(shape)}, got {len(axes)}.") - return Placeholder(name, shape, dtype, axes=axes, tags=tags) + return Placeholder(name, shape, dtype, 
axes=axes, + tags=(tags | _get_default_tags())) def make_size_param(name: str, @@ -1939,7 +1956,7 @@ def make_size_param(name: str, :param tags: implementation tags """ _check_identifier(name, optional=False) - return SizeParam(name, tags=tags) + return SizeParam(name, tags=(tags | _get_default_tags())) def make_data_wrapper(data: DataInterface, @@ -1967,7 +1984,9 @@ def make_data_wrapper(data: DataInterface, raise ValueError("'axes' dimensionality mismatch:" f" expected {len(shape)}, got {len(axes)}.") - return DataWrapper(name, data, shape, axes=axes, tags=tags) + return DataWrapper(name, data, shape, + axes=axes, + tags=(tags | _get_default_tags())) # }}} @@ -1985,6 +2004,7 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType, shape = normalize_shape(shape) dtype = np.dtype(dtype) return IndexLambda(dtype.type(fill_value), shape, dtype, {}, + tags=_get_default_tags(), axes=_get_default_axes(len(shape))) @@ -2029,6 +2049,7 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803 return IndexLambda(parse(f"1 if ((_1 - _0) == {k}) else 0"), shape=(N, M), dtype=dtype, bindings={}, + tags=_get_default_tags(), axes=_get_default_axes(2)) # }}} @@ -2122,6 +2143,7 @@ def arange(*args: Any, **kwargs: Any) -> Array: from pymbolic.primitives import Variable return IndexLambda(start + Variable("_0") * step, shape=(size,), dtype=dtype, bindings={}, + tags=_get_default_tags(), axes=_get_default_axes(1)) # }}} @@ -2222,6 +2244,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]: shape=x.shape, dtype=np.dtype(np.bool8), bindings={"_in0": x}, + tags=_get_default_tags(), axes=_get_default_axes(len(x.shape))) # }}} @@ -2274,6 +2297,7 @@ def where(condition: ArrayOrScalar, shape=result_shape, dtype=dtype, bindings=bindings, + tags=_get_default_tags(), axes=_get_default_axes(len(result_shape))) # }}} @@ -2330,6 +2354,7 @@ def make_index_lambda( bindings=bindings, shape=shape, dtype=dtype, + tags=_get_default_tags(), axes=_get_default_axes(len(shape))) # }}} @@ -2407,6 +2432,7 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: shape=shape, dtype=array.dtype, bindings={"in": array}, + tags=_get_default_tags(), axes=_get_default_axes(len(shape))) diff --git a/pytato/cmath.py b/pytato/cmath.py index 8d4fc1da0..5ada73e6b 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -59,7 +59,7 @@ import pymbolic.primitives as prim from typing import Tuple, Optional from pytato.array import (Array, ArrayOrScalar, IndexLambda, _dtype_any, - _get_default_axes) + _get_default_axes, _get_default_tags) from pytato.scalar_expr import SCALAR_CLASSES from pymbolic import var @@ -110,8 +110,11 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], assert ret_dtype is not None return IndexLambda( - prim.Call(var(f"pytato.c99.{func_name}"), tuple(sym_args)), - shape, ret_dtype, bindings, axes=_get_default_axes(len(shape))) + prim.Call(var(f"pytato.c99.{func_name}"), + tuple(sym_args)), + shape, ret_dtype, bindings, + tags=_get_default_tags(), + axes=_get_default_axes(len(shape))) def abs(x: Array) -> ArrayOrScalar: diff --git a/pytato/utils.py b/pytato/utils.py index 51959e2f5..c475e1da1 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -169,7 +169,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 ) -> ArrayOrScalar: - from pytato.array import _get_default_axes + from pytato.array import 
_get_default_axes, _get_default_tags if isinstance(a1, SCALAR_CLASSES): a1 = np.dtype(type(a1)).type(a1) @@ -196,6 +196,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, shape=result_shape, dtype=result_dtype, bindings=bindings, + tags=_get_default_tags(), axes=_get_default_axes(len(result_shape))) @@ -461,7 +462,7 @@ def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent: def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Array: from pytato.diagnostic import CannotBroadcastError - from pytato.array import _get_default_axes + from pytato.array import _get_default_axes, _get_default_tags # {{{ handle ellipsis @@ -543,18 +544,21 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra return AdvancedIndexInNoncontiguousAxes( ary, tuple(normalized_indices), + tags=_get_default_tags(), axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: return AdvancedIndexInContiguousAxes( ary, tuple(normalized_indices), + tags=_get_default_tags(), axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: # basic indexing expression return BasicIndex(ary, tuple(normalized_indices), + tags=_get_default_tags(), axes=_get_default_axes( len([idx for idx in normalized_indices From 8fee6d4de8e9134333d29a4d0ae6cb3fd0810248 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 23 Mar 2022 17:19:38 -0500 Subject: [PATCH 004/178] Change back to UniqueTag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- pytato/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index 285062a5c..f1ed4984c 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -104,7 +104,7 @@ class AssumeNonNegative(Tag): @tag_dataclass -class CreatedAt(Tag): +class CreatedAt(UniqueTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. 
From bdff59ceb4e5e993e11b0e5c6ac4ed8cb5d0a459 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 23 Mar 2022 17:42:39 -0500 Subject: [PATCH 005/178] use _get_default_tags --- pytato/array.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index decfd4d1f..a90354e93 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -419,12 +419,8 @@ class Array(Taggable): __array_priority__ = 1 # disallow numpy arithmetic to take precedence def __init__(self, axes: AxesT, tags: FrozenSet[Tag]) -> None: - import traceback - v = "".join(traceback.format_stack()) - from pytato.tags import CreatedAt - c = CreatedAt(v) self.axes = axes - self.tags = frozenset({*tags, c}) + self.tags = tags def copy(self: ArrayT, **kwargs: Any) -> ArrayT: for field in self._fields: @@ -1680,7 +1676,12 @@ def _get_default_axes(ndim: int) -> AxesT: def _get_default_tags() -> TagsType: - return frozenset() + import traceback + from pytato.tags import CreatedAt + + v = "".join(traceback.format_stack()) + c = CreatedAt(v) + return frozenset((c,)) def _get_matmul_ndim(ndim1: int, ndim2: int) -> int: From 6d181447ee2564097c24a3d8658bef378400274a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 23 Mar 2022 19:01:23 -0500 Subject: [PATCH 006/178] store a tupleized StackSummary --- pytato/array.py | 7 +++++-- test/test_pytato.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index a90354e93..ed56cbda5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1679,8 +1679,11 @@ def _get_default_tags() -> TagsType: import traceback from pytato.tags import CreatedAt - v = "".join(traceback.format_stack()) - c = CreatedAt(v) + # extract_stack returns a StackSummary, which is a list + # You can restore the StackSummary object by calling + # StackSummary.from_list(c.traceback) + stack_summary = traceback.extract_stack() + c = CreatedAt(tuple(tuple(t) for t in tuple(stack_summary))) return frozenset((c,)) diff --git a/test/test_pytato.py b/test/test_pytato.py index 940ff2353..523aa5639 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -797,12 +797,18 @@ def test_created_at(): res = a+b from pytato.tags import CreatedAt - assert any(isinstance(tag, CreatedAt) for tag in res.tags) + + found = False # Make sure the function name appears in the traceback for tag in res.tags: if isinstance(tag, CreatedAt): - assert "test_created_at" in tag.traceback + for line in tag.traceback: + if line[2] == "test_created_at": + found = True + break + + assert found if __name__ == "__main__": From 2c3faebb4729b6bfd6140aa7410d03edf59a7d55 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 24 Mar 2022 11:34:35 -0500 Subject: [PATCH 007/178] work around mypy --- pytato/tags.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index f1ed4984c..9a4f5c40e 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -12,6 +12,7 @@ """ +from dataclasses import dataclass from pytools.tag import Tag, UniqueTag, tag_dataclass @@ -103,7 +104,9 @@ class AssumeNonNegative(Tag): """ -@tag_dataclass +# See https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues +# on why this can not be '@tag_dataclass'. 
+@dataclass(init=True, eq=True, frozen=True, repr=True) class CreatedAt(UniqueTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback From 739f3d3032e34e88b8f8e3d9df8588f48224a0c7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 24 Mar 2022 13:23:24 -0500 Subject: [PATCH 008/178] more line fixes --- pytato/tags.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytato/tags.py b/pytato/tags.py index 9a4f5c40e..dd35c5dc1 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from pytools.tag import Tag, UniqueTag, tag_dataclass +from typing import Tuple, Any # {{{ pre-defined tag: ImplementationStrategy @@ -104,7 +105,8 @@ class AssumeNonNegative(Tag): """ -# See https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues +# See +# https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) class CreatedAt(UniqueTag): @@ -113,4 +115,4 @@ class CreatedAt(UniqueTag): of where it was created. """ - traceback: str + traceback: Tuple[Tuple[Any, ...], ...] From 4fd3d64439b891abac11571e00d49fd646e1be05 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 14:33:03 -0500 Subject: [PATCH 009/178] use a class for the traceback instead of tuples --- pytato/array.py | 31 ++++++++++++++++++++++++++----- pytato/tags.py | 3 ++- test/test_pytato.py | 4 ++-- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 11432184b..8757b537b 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1,4 +1,5 @@ from __future__ import annotations +from traceback import FrameSummary, StackSummary __copyright__ = """ Copyright (C) 2020 Andreas Kloeckner @@ -1675,15 +1676,35 @@ def _get_default_axes(ndim: int) -> AxesT: return tuple(Axis(frozenset()) for _ in range(ndim)) +@dataclass(frozen=True, eq=True) +class _PytatoFrameSummary: + filename: str + lineno: int + name: str + line: str + + +class _PytatoStackSummary(Tag): + def __init__(self, stack_summary: StackSummary) -> None: + self.frames: List[_PytatoFrameSummary] = [] + for s in stack_summary: + pfs = _PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) + self.frames.append(pfs) + + def to_stacksummary(self) -> StackSummary: + frames = [] + for f in self.frames: + frames.append(FrameSummary(f.filename, f.lineno, f.name, line=f.line)) + + # type-ignore-reason: from_list also takes List[FrameSummary] + return StackSummary.from_list(frames) # type: ignore[arg-type] + + def _get_default_tags() -> TagsType: import traceback from pytato.tags import CreatedAt - # extract_stack returns a StackSummary, which is a list - # You can restore the StackSummary object by calling - # StackSummary.from_list(c.traceback) - stack_summary = traceback.extract_stack() - c = CreatedAt(tuple(tuple(t) for t in tuple(stack_summary))) + c = CreatedAt(_PytatoStackSummary(traceback.extract_stack())) return frozenset((c,)) diff --git a/pytato/tags.py b/pytato/tags.py index dd35c5dc1..4f2bdaef5 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from pytools.tag import Tag, UniqueTag, tag_dataclass from typing import Tuple, Any +from pytato.array import _PytatoStackSummary # {{{ pre-defined tag: ImplementationStrategy @@ -115,4 +116,4 @@ class CreatedAt(UniqueTag): of where it was created. """ - traceback: Tuple[Tuple[Any, ...], ...] 
+ traceback: _PytatoStackSummary diff --git a/test/test_pytato.py b/test/test_pytato.py index 523aa5639..0e090ff0d 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -803,8 +803,8 @@ def test_created_at(): # Make sure the function name appears in the traceback for tag in res.tags: if isinstance(tag, CreatedAt): - for line in tag.traceback: - if line[2] == "test_created_at": + for frame in tag.traceback.frames: + if frame.name == "test_created_at": found = True break From 770255024b439f1e7d157d18d5e1607fb6dcc783 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 14:36:42 -0500 Subject: [PATCH 010/178] also test to_stacksummary --- test/test_pytato.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 0e090ff0d..817aad5fa 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -803,6 +803,7 @@ def test_created_at(): # Make sure the function name appears in the traceback for tag in res.tags: if isinstance(tag, CreatedAt): + _unused = tag.traceback.to_stacksummary() for frame in tag.traceback.frames: if frame.name == "test_created_at": found = True From 7a8655770198d0902f3331479e0775da514f4956 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 14:41:18 -0500 Subject: [PATCH 011/178] flake8 --- pytato/tags.py | 1 - test/test_pytato.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/tags.py b/pytato/tags.py index 4f2bdaef5..77d2fe218 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -14,7 +14,6 @@ from dataclasses import dataclass from pytools.tag import Tag, UniqueTag, tag_dataclass -from typing import Tuple, Any from pytato.array import _PytatoStackSummary diff --git a/test/test_pytato.py b/test/test_pytato.py index 817aad5fa..d642c7486 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -803,7 +803,7 @@ def test_created_at(): # Make sure the function name appears in the traceback for tag in res.tags: if isinstance(tag, CreatedAt): - _unused = tag.traceback.to_stacksummary() + _unused = tag.traceback.to_stacksummary() # noqa for frame in tag.traceback.frames: if frame.name == "test_created_at": found = True From 31bcda8d3406fbf2d1f4b36a43c1b2e26bf78104 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 28 Mar 2022 15:07:19 -0500 Subject: [PATCH 012/178] Add remove_tags_of_type --- pytato/transform.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pytato/transform.py b/pytato/transform.py index 4af6f8ab3..a4d3a1f88 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -74,6 +74,7 @@ .. autofunction:: copy_dict_of_named_arrays .. autofunction:: get_dependencies .. autofunction:: map_and_copy +.. autofunction:: remove_tags_of_type .. autofunction:: materialize_with_mpms Dict representation of DAGs @@ -1031,6 +1032,21 @@ def map_and_copy(expr: ArrayOrNames, return CachedMapAndCopyMapper(map_fn)(expr) +def remove_tags_of_type(tag_types: Union[type, Tuple[type]], expr: ArrayOrNames + ) -> ArrayOrNames: + def process_node(expr: ArrayOrNames) -> ArrayOrNames: + if isinstance(expr, Array): + return expr.copy(tags=frozenset({ + tag for tag in expr.tags + if not isinstance(tag, tag_types)})) + elif isinstance(expr, AbstractResultWithNamedArrays): + return expr + else: + raise AssertionError() + + return map_and_copy(expr, process_node) + + def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: r""" Materialize nodes in *expr* with MPMS materialization strategy. 
From 5c3222229519d570b9da0d89379c76784779c418 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 28 Mar 2022 15:07:38 -0500 Subject: [PATCH 013/178] test_array_dot_repr: Remove CreatedAt tags before comparing --- test/test_pytato.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index d642c7486..c4f3c6e6e 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -27,6 +27,8 @@ import sys +from typing import cast + import numpy as np import pytest @@ -462,7 +464,12 @@ def test_array_dot_repr(): x = pt.make_placeholder("x", (10, 4), np.int64) y = pt.make_placeholder("y", (10, 4), np.int64) + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + def _assert_stripped_repr(ary: pt.Array, expected_repr: str): + ary = cast(pt.Array, remove_tags_of_type(CreatedAt, ary)) + expected_str = "".join([c for c in repr(ary) if c not in [" ", "\n"]]) result_str = "".join([c for c in expected_repr if c not in [" ", "\n"]]) assert expected_str == result_str From 2d0358fbdd4ea298022be4587eb919cf6a179569 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 15:09:29 -0500 Subject: [PATCH 014/178] only add CreatedAt in debug mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- pytato/array.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 8757b537b..a59414312 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1704,8 +1704,9 @@ def _get_default_tags() -> TagsType: import traceback from pytato.tags import CreatedAt - c = CreatedAt(_PytatoStackSummary(traceback.extract_stack())) - return frozenset((c,)) + if __debug__: + c = CreatedAt(_PytatoStackSummary(traceback.extract_stack())) + return frozenset((c,)) def matmul(x1: Array, x2: Array) -> Array: From 1c275fd80cca82a6f584f7b20b33549c3ff977be Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 15:26:10 -0500 Subject: [PATCH 015/178] restructure test_created_at --- test/test_pytato.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index c4f3c6e6e..bd4dbdd26 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -805,16 +805,20 @@ def test_created_at(): from pytato.tags import CreatedAt + created_tag = res.tags_of_type(CreatedAt) + + assert len(created_tag) == 1 + + tag, = created_tag + found = False # Make sure the function name appears in the traceback - for tag in res.tags: - if isinstance(tag, CreatedAt): - _unused = tag.traceback.to_stacksummary() # noqa - for frame in tag.traceback.frames: - if frame.name == "test_created_at": - found = True - break + _unused = tag.traceback.to_stacksummary() # noqa + for frame in tag.traceback.frames: + if frame.name == "test_created_at": + found = True + break assert found From 02362bfaf8f6d6da00e2b2e639b9b6773866aa14 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 15:27:14 -0500 Subject: [PATCH 016/178] make _PytatoStackSummary a dataclass --- pytato/array.py | 46 +++++++++++++++++++++++++++++++++++---------- pytato/transform.py | 2 +- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index a59414312..3ed376caf 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1678,35 +1678,61 @@ def _get_default_axes(ndim: int) -> AxesT: @dataclass(frozen=True, eq=True) class _PytatoFrameSummary: + 
"""Class to store a single call frame.""" filename: str lineno: int name: str line: str + def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: + key_builder.rec(key_hash, + (self.__class__.__module__, self.__class__.__qualname__)) -class _PytatoStackSummary(Tag): - def __init__(self, stack_summary: StackSummary) -> None: - self.frames: List[_PytatoFrameSummary] = [] - for s in stack_summary: - pfs = _PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) - self.frames.append(pfs) + from dataclasses import fields + # Fields are ordered consistently, so ordered hashing is OK. + # + # No need to dispatch to superclass: fields() automatically gives us + # fields from the entire class hierarchy. + for f in fields(self): + key_builder.rec(key_hash, getattr(self, f.name)) + + +@dataclass(frozen=True, eq=True) +class _PytatoStackSummary: + """Class to store a list of :class:`_PytatoFrameSummary` call frames.""" + frames: Tuple[_PytatoFrameSummary, ...] def to_stacksummary(self) -> StackSummary: - frames = [] - for f in self.frames: - frames.append(FrameSummary(f.filename, f.lineno, f.name, line=f.line)) + frames = [FrameSummary(f.filename, f.lineno, f.name, line=f.line) + for f in self.frames] # type-ignore-reason: from_list also takes List[FrameSummary] return StackSummary.from_list(frames) # type: ignore[arg-type] + def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: + key_builder.rec(key_hash, + (self.__class__.__module__, self.__class__.__qualname__)) + + from dataclasses import fields + # Fields are ordered consistently, so ordered hashing is OK. + # + # No need to dispatch to superclass: fields() automatically gives us + # fields from the entire class hierarchy. + for f in fields(self): + key_builder.rec(key_hash, getattr(self, f.name)) + def _get_default_tags() -> TagsType: import traceback from pytato.tags import CreatedAt if __debug__: - c = CreatedAt(_PytatoStackSummary(traceback.extract_stack())) + frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) + for s in traceback.extract_stack()) + c = CreatedAt(_PytatoStackSummary(frames)) return frozenset((c,)) + else: + return frozenset() def matmul(x1: Array, x2: Array) -> Array: diff --git a/pytato/transform.py b/pytato/transform.py index a4d3a1f88..fe39e1243 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -1033,7 +1033,7 @@ def map_and_copy(expr: ArrayOrNames, def remove_tags_of_type(tag_types: Union[type, Tuple[type]], expr: ArrayOrNames - ) -> ArrayOrNames: + ) -> ArrayOrNames: def process_node(expr: ArrayOrNames) -> ArrayOrNames: if isinstance(expr, Array): return expr.copy(tags=frozenset({ From 7b1f7b81bbb33c0567285f3e8f166d95ef4a180a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 22:11:01 -0500 Subject: [PATCH 017/178] add __repr__ --- pytato/array.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytato/array.py b/pytato/array.py index 3ed376caf..37cc653bf 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1696,6 +1696,9 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) + def __repr__(self) -> str: + return f"{self.filename}:{self.lineno}, in {self.name}: {self.line}" + @dataclass(frozen=True, eq=True) class _PytatoStackSummary: @@ -1721,6 +1724,9 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) + def __repr__(self) -> 
str: + return "\n " + "\n ".join([str(f) for f in self.frames]) + def _get_default_tags() -> TagsType: import traceback From 07b2fa16c92f5b9b9c2c923ed65dd0e61e502e6d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 23:00:20 -0500 Subject: [PATCH 018/178] fix 2 tests --- pytato/transform.py | 2 +- test/test_pytato.py | 28 ++++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/pytato/transform.py b/pytato/transform.py index fe39e1243..adb9f1346 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -1042,7 +1042,7 @@ def process_node(expr: ArrayOrNames) -> ArrayOrNames: elif isinstance(expr, AbstractResultWithNamedArrays): return expr else: - raise AssertionError() + raise AssertionError(type(expr)) return map_and_copy(expr, process_node) diff --git a/test/test_pytato.py b/test/test_pytato.py index bd4dbdd26..f227c8a75 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -297,6 +297,14 @@ def test_dict_of_named_arrays_comparison(): dict2 = pt.make_dict_of_named_arrays({"out": 2 * x}) dict3 = pt.make_dict_of_named_arrays({"not_out": 2 * x}) dict4 = pt.make_dict_of_named_arrays({"out": 3 * x}) + + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + dict1 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict1)) + dict2 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict2)) + dict3 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict3)) + dict4 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict4)) + assert dict1 == dict2 assert dict1 != dict3 assert dict1 != dict4 @@ -626,10 +634,22 @@ def test_rec_get_user_nodes(): expr = pt.make_dict_of_named_arrays({"out1": 2 * x1, "out2": 7 * x1 + 3 * x2}) - assert (pt.transform.rec_get_user_nodes(expr, x1) - == frozenset({2 * x1, 7*x1, 7*x1 + 3 * x2, expr})) - assert (pt.transform.rec_get_user_nodes(expr, x2) - == frozenset({3 * x2, 7*x1 + 3 * x2, expr})) + t1 = pt.transform.rec_get_user_nodes(expr, x1) + t1r = frozenset({2 * x1, 7*x1, 7*x1 + 3 * x2, expr}) + + t2 = pt.transform.rec_get_user_nodes(expr, x2) + t2r = frozenset({3 * x2, 7*x1 + 3 * x2, expr}) + + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + + t1 = frozenset({remove_tags_of_type(CreatedAt, t) for t in t1}) + t1r = frozenset({remove_tags_of_type(CreatedAt, t) for t in t1r}) + t2 = frozenset({remove_tags_of_type(CreatedAt, t) for t in t2}) + t2r = frozenset({remove_tags_of_type(CreatedAt, t) for t in t2r}) + + assert (t1 == t1r) + assert (t2 == t2r) def test_rec_get_user_nodes_linear_complexity(): From 437954cbb203793890c16616bbcabab0757f04a2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 23:24:00 -0500 Subject: [PATCH 019/178] illustrate test failure with construct_intestine_graph --- test/test_pytato.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index f227c8a75..ac8fd106b 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -707,9 +707,29 @@ def post_visit(self, expr): expected_result[expr] = {"foo"} expr, inp = construct_intestine_graph() + + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + # node_to_users = remove_tags_of_type(CreatedAt, user_collector.node_to_users) + + node_to_users = {} + + for k in user_collector.node_to_users.keys(): + new_key = remove_tags_of_type(CreatedAt, k) + new_values = set({remove_tags_of_type(CreatedAt, v) for v in user_collector.node_to_users[k]}) + + node_to_users[new_key] 
= new_values + + + result = pt.transform.tag_user_nodes(user_collector.node_to_users, "foo", inp) ExpectedResultComputer()(expr) + import pudb + pu.db + + + assert expected_result == result From f05592e87cea001c195714fe4a7c59deac0fb28f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 29 Mar 2022 10:44:55 -0500 Subject: [PATCH 020/178] shorten traceback printing --- pytato/array.py | 28 +++++++++++++++++++++++++++- pytato/tags.py | 3 +++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 37cc653bf..08e7e4bea 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1697,7 +1697,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: key_builder.rec(key_hash, getattr(self, f.name)) def __repr__(self) -> str: - return f"{self.filename}:{self.lineno}, in {self.name}: {self.line}" + return f"{self.filename}:{self.lineno}, in {self.name}(): {self.line}" @dataclass(frozen=True, eq=True) @@ -1724,6 +1724,32 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) + def __str__(self) -> str: + from os.path import dirname + + res = None + + # Find the first file in the frames that it is not in pytato's pytato/ + # directory. + for idx, frame in enumerate(reversed(self.frames)): + frame_dir = dirname(frame.filename) + if not frame_dir.endswith("pytato"): + res = str(frame) + + # Indicate whether frames were omitted + if idx < len(self.frames)-1: + res += " ..." + if idx > 0: + res = "... " + res + break + + if not res: + # Fallback in case we don't find any file that is not in the pytato/ + # directory (should be unlikely). + return self.__repr__() + + return res + def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) diff --git a/pytato/tags.py b/pytato/tags.py index 77d2fe218..36b28d588 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -116,3 +116,6 @@ class CreatedAt(UniqueTag): """ traceback: _PytatoStackSummary + + def __repr__(self) -> str: + return "CreatedAt(" + str(self.traceback) + ")" From d0409968304c553f7c62a28463f74881e4e8a4c3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 30 Mar 2022 11:38:35 -0500 Subject: [PATCH 021/178] use separate field for CreatedAt --- pytato/array.py | 32 ++++++++++++++------------------ pytato/visualization.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 08e7e4bea..074775b5d 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1696,6 +1696,14 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) + def short_str(self) -> str: + s = f"{self.filename}:{self.lineno}, in {self.name}():\n{self.line}" + s1, s2 = s.split("\n") + # Limit display to 35 characters + s1 = "[...] " + s1[len(s1)-35:] if len(s1) > 35 else s1 + s2 = s2[:35] + " [...]" if len(s2) > 35 else s2 + return s1 + "\n" + s2 + def __repr__(self) -> str: return f"{self.filename}:{self.lineno}, in {self.name}(): {self.line}" @@ -1724,31 +1732,19 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) - def __str__(self) -> str: + def short_str(self) -> str: from os.path import dirname - res = None - # Find the first file in the frames that it is not in pytato's pytato/ # directory. 
- for idx, frame in enumerate(reversed(self.frames)): + for frame in reversed(self.frames): frame_dir = dirname(frame.filename) if not frame_dir.endswith("pytato"): - res = str(frame) - - # Indicate whether frames were omitted - if idx < len(self.frames)-1: - res += " ..." - if idx > 0: - res = "... " + res - break - - if not res: - # Fallback in case we don't find any file that is not in the pytato/ - # directory (should be unlikely). - return self.__repr__() + return frame.short_str() - return res + # Fallback in case we don't find any file that is not in the pytato/ + # directory (should be unlikely). + return self.__repr__() def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) diff --git a/pytato/visualization.py b/pytato/visualization.py index 573aa1d04..d755afe6b 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -73,13 +73,40 @@ class DotNodeInfo: edges: Dict[str, ArrayOrNames] +def stringify_created_at(tags: TagsType) -> str: + from pytato.tags import CreatedAt + for tag in tags: + if isinstance(tag, CreatedAt): + return tag.traceback.short_str() + + return "" + + def stringify_tags(tags: TagsType) -> str: + # The CreatedAt tag is handled in stringify_created_at() + from pytato.tags import CreatedAt + tags = set(tag for tag in tags if not isinstance(tag, CreatedAt)) + components = sorted(str(elem) for elem in tags) return "{" + ", ".join(components) + "}" def stringify_shape(shape: ShapeType) -> str: - components = [str(elem) for elem in shape] + from pytato.tags import CreatedAt + from pytato import SizeParam + + new_elems = set() + for elem in shape: + # Remove CreatedAt tags from SizeParam + if isinstance(elem, SizeParam): + new_elem = elem.copy( + tags=frozenset(tag for tag in elem.tags + if not isinstance(tag, CreatedAt))) + new_elems.add(new_elem) + else: + new_elems.add(elem) + + components = [str(elem) for elem in new_elems] if not components: components = [","] elif len(components) == 1: @@ -95,6 +122,7 @@ def __init__(self) -> None: def get_common_dot_info(self, expr: Array) -> DotNodeInfo: title = type(expr).__name__ fields = dict(addr=hex(id(expr)), + created_at=stringify_created_at(expr.tags), shape=stringify_shape(expr.shape), dtype=str(expr.dtype), tags=stringify_tags(expr.tags)) From 9fdd602a66f82f4be67175bc85bde4289f318eab Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 30 Mar 2022 13:09:39 -0500 Subject: [PATCH 022/178] fix tests --- pytato/visualization.py | 10 +++++----- test/test_pytato.py | 38 +++++++++++++++++--------------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index d755afe6b..68b397067 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -85,7 +85,7 @@ def stringify_created_at(tags: TagsType) -> str: def stringify_tags(tags: TagsType) -> str: # The CreatedAt tag is handled in stringify_created_at() from pytato.tags import CreatedAt - tags = set(tag for tag in tags if not isinstance(tag, CreatedAt)) + tags = frozenset(tag for tag in tags if not isinstance(tag, CreatedAt)) components = sorted(str(elem) for elem in tags) return "{" + ", ".join(components) + "}" @@ -97,14 +97,14 @@ def stringify_shape(shape: ShapeType) -> str: new_elems = set() for elem in shape: - # Remove CreatedAt tags from SizeParam - if isinstance(elem, SizeParam): + if not isinstance(elem, SizeParam): + new_elems.add(elem) + else: + # Remove CreatedAt tags from SizeParam new_elem = elem.copy( tags=frozenset(tag for tag in elem.tags 
if not isinstance(tag, CreatedAt))) new_elems.add(new_elem) - else: - new_elems.add(elem) components = [str(elem) for elem in new_elems] if not components: diff --git a/test/test_pytato.py b/test/test_pytato.py index ac8fd106b..38f937229 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -399,7 +399,7 @@ def test_linear_complexity_inequality(): from pytato.equality import EqualityComparer from numpy.random import default_rng - def construct_intestine_graph(depth=100, seed=0): + def construct_intestine_graph(depth=90, seed=0): rng = default_rng(seed) x = pt.make_placeholder("x", shape=(10,), dtype=float) @@ -413,6 +413,13 @@ def construct_intestine_graph(depth=100, seed=0): graph2 = construct_intestine_graph() graph3 = construct_intestine_graph(seed=3) + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + + graph1 = remove_tags_of_type(CreatedAt, graph1) + graph2 = remove_tags_of_type(CreatedAt, graph2) + graph3 = remove_tags_of_type(CreatedAt, graph3) + assert EqualityComparer()(graph1, graph2) assert EqualityComparer()(graph2, graph1) assert not EqualityComparer()(graph1, graph3) @@ -685,7 +692,7 @@ def post_visit(self, expr): def test_tag_user_nodes_linear_complexity(): from numpy.random import default_rng - def construct_intestine_graph(depth=100, seed=0): + def construct_intestine_graph(depth=90, seed=0): rng = default_rng(seed) x = pt.make_placeholder("x", shape=(10,), dtype=float) y = x @@ -696,7 +703,13 @@ def construct_intestine_graph(depth=100, seed=0): return y, x + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + expr, inp = construct_intestine_graph() + expr = remove_tags_of_type(CreatedAt, expr) + inp = remove_tags_of_type(CreatedAt, inp) + user_collector = pt.transform.UsersCollector() user_collector(expr) @@ -707,29 +720,12 @@ def post_visit(self, expr): expected_result[expr] = {"foo"} expr, inp = construct_intestine_graph() - - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - # node_to_users = remove_tags_of_type(CreatedAt, user_collector.node_to_users) - - node_to_users = {} - - for k in user_collector.node_to_users.keys(): - new_key = remove_tags_of_type(CreatedAt, k) - new_values = set({remove_tags_of_type(CreatedAt, v) for v in user_collector.node_to_users[k]}) - - node_to_users[new_key] = new_values - - + expr = remove_tags_of_type(CreatedAt, expr) + inp = remove_tags_of_type(CreatedAt, inp) result = pt.transform.tag_user_nodes(user_collector.node_to_users, "foo", inp) ExpectedResultComputer()(expr) - import pudb - pu.db - - - assert expected_result == result From e606a4823264a936d04513cce99dddfde033a1d3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 30 Mar 2022 15:24:57 -0500 Subject: [PATCH 023/178] fix doctest --- pytato/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/utils.py b/pytato/utils.py index fe48c1368..8bd32139d 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -242,6 +242,8 @@ def dim_to_index_lambda_components(expr: ShapeComponent, .. 
testsetup:: >>> import pytato as pt + >>> from pytato.transform import remove_tags_of_type + >>> from pytato.tags import CreatedAt >>> from pytato.utils import dim_to_index_lambda_components >>> from pytools import UniqueNameGenerator @@ -251,7 +253,7 @@ def dim_to_index_lambda_components(expr: ShapeComponent, >>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator()) >>> print(expr) 3*_in + 8 - >>> bnds + >>> {"_in": remove_tags_of_type(CreatedAt, bnds["_in"])} {'_in': SizeParam(name='n')} """ if isinstance(expr, INT_CLASSES): From 235d9a72c009d8765bed92fc8130fd6dbad79f94 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 30 Mar 2022 17:25:35 -0500 Subject: [PATCH 024/178] make it a tag again --- pytato/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index 36b28d588..6c6f773f6 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -109,7 +109,7 @@ class AssumeNonNegative(Tag): # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag): +class CreatedAt(Tag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From d80066f45b80130f54e105f41a5f1520187d0381 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 1 Apr 2022 10:36:00 -0500 Subject: [PATCH 025/178] use tooltip instead of table row --- pytato/array.py | 12 ++++++------ pytato/visualization.py | 11 +++++++---- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 074775b5d..e6546edb7 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1696,12 +1696,12 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) - def short_str(self) -> str: + def short_str(self, maxlen: int = 100) -> str: s = f"{self.filename}:{self.lineno}, in {self.name}():\n{self.line}" s1, s2 = s.split("\n") - # Limit display to 35 characters - s1 = "[...] " + s1[len(s1)-35:] if len(s1) > 35 else s1 - s2 = s2[:35] + " [...]" if len(s2) > 35 else s2 + # Limit display to maxlen characters + s1 = "[...] " + s1[len(s1)-maxlen:] if len(s1) > maxlen else s1 + s2 = s2[:maxlen] + " [...]" if len(s2) > maxlen else s2 return s1 + "\n" + s2 def __repr__(self) -> str: @@ -1732,7 +1732,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) - def short_str(self) -> str: + def short_str(self, maxlen: int = 100) -> str: from os.path import dirname # Find the first file in the frames that it is not in pytato's pytato/ @@ -1740,7 +1740,7 @@ def short_str(self) -> str: for frame in reversed(self.frames): frame_dir = dirname(frame.filename) if not frame_dir.endswith("pytato"): - return frame.short_str() + return frame.short_str(maxlen) # Fallback in case we don't find any file that is not in the pytato/ # directory (should be unlikely). 
diff --git a/pytato/visualization.py b/pytato/visualization.py
index 68b397067..498a0d58a 100644
--- a/pytato/visualization.py
+++ b/pytato/visualization.py
@@ -293,8 +293,10 @@ def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str],
     td_attrib = 'border="0"'
     table_attrib = 'border="0" cellborder="1" cellspacing="0"'
 
-    rows = ['<tr><td colspan="2" %s>%s</td></tr>'
-            % (td_attrib, dot_escape(title))]
+    rows = [f"<tr><td colspan='2' {td_attrib}>{dot_escape(title)}</td></tr>"]
+
+    created_at = fields.pop("created_at", "")
+    tooltip = dot_escape(created_at)
 
     for name, field in fields.items():
         field_content = dot_escape(field).replace("\n", "<br/>")
         rows.append(
             f"<tr><td {td_attrib}>{dot_escape(name)}:</td>"
             f"<td {td_attrib}>{field_content}</td></tr>"
         )
 
-    table = "<table %s>\n%s</table>" % (table_attrib, "".join(rows))
-    emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color))
+    table = f"<table {table_attrib}>\n{''.join(rows)}</table>
" + emit(f"{dot_node_id} [label=<{table}> style=filled fillcolor={color} " + f'tooltip="{tooltip}"]') def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, ArrayOrNames], From ef3339f906224f6fefbc04834e0dff6c534191e5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 1 Apr 2022 11:21:04 -0500 Subject: [PATCH 026/178] force openmpi usage --- .test-conda-env-py3.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 6459d4c06..7917877fb 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -12,3 +12,4 @@ dependencies: - islpy - sphinx-autodoc-typehints - mpi4py +- openmpi From 1ff1a2b49f8c867bab87d00907ba3271a11a048e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 1 Apr 2022 15:00:04 -0500 Subject: [PATCH 027/178] check for existing CreatedAt and make it a UniqueTag again --- pytato/array.py | 14 ++++++++------ pytato/tags.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index e6546edb7..9a18f4081 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1735,7 +1735,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: def short_str(self, maxlen: int = 100) -> str: from os.path import dirname - # Find the first file in the frames that it is not in pytato's pytato/ + # Find the first file in the frames that is not in pytato's pytato/ # directory. for frame in reversed(self.frames): frame_dir = dirname(frame.filename) @@ -1750,11 +1750,13 @@ def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) -def _get_default_tags() -> TagsType: +def _get_default_tags(existing_tags: Optional[TagsType] = None) -> TagsType: import traceback from pytato.tags import CreatedAt - if __debug__: + if __debug__ and ( + existing_tags is None + or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) for s in traceback.extract_stack()) c = CreatedAt(_PytatoStackSummary(frames)) @@ -2015,7 +2017,7 @@ def make_placeholder(name: str, f" expected {len(shape)}, got {len(axes)}.") return Placeholder(name, shape, dtype, axes=axes, - tags=(tags | _get_default_tags())) + tags=(tags | _get_default_tags(tags))) def make_size_param(name: str, @@ -2029,7 +2031,7 @@ def make_size_param(name: str, :param tags: implementation tags """ _check_identifier(name, optional=False) - return SizeParam(name, tags=(tags | _get_default_tags())) + return SizeParam(name, tags=(tags | _get_default_tags(tags))) def make_data_wrapper(data: DataInterface, @@ -2059,7 +2061,7 @@ def make_data_wrapper(data: DataInterface, return DataWrapper(name, data, shape, axes=axes, - tags=(tags | _get_default_tags())) + tags=(tags | _get_default_tags(tags))) # }}} diff --git a/pytato/tags.py b/pytato/tags.py index 6c6f773f6..36b28d588 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -109,7 +109,7 @@ class AssumeNonNegative(Tag): # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(Tag): +class CreatedAt(UniqueTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. 
From c57a4a17b2757e74cff63fd54ec806dc6acb594d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 10:50:51 -0500 Subject: [PATCH 028/178] flake8 --- pytato/visualization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index 91af34607..d5693ed5a 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -31,10 +31,9 @@ import html from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List, - Mapping, Hashable, Any, FrozenSet) + Mapping, Hashable, Any) from pytools import UniqueNameGenerator -from pytools.tag import Tag from pytools.codegen import CodeGenerator as CodeGeneratorBase from pytato.loopy import LoopyCall From 4ae31b1ef8416e152899fe5f0807a8dc560655a1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 10:54:53 -0500 Subject: [PATCH 029/178] add simple equality test --- test/test_pytato.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 28f1a723e..f2ad7d9b4 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -844,6 +844,10 @@ def test_created_at(): b = pt.make_placeholder("b", (10, 10), "float64") res = a+b + res2 = a+b + + # CreatedAt tags need to be filtered for equality to work correctly. + assert res == res2 from pytato.tags import CreatedAt From f559a59bf35a1d77a7445c3ca5b8bb633aa64849 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 13:17:32 -0500 Subject: [PATCH 030/178] lint fixes --- pytato/array.py | 9 +++++---- pytato/visualization.py | 7 ++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 36e5de2e9..2ed902a46 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -258,7 +258,7 @@ def normalize_shape_component( # }}} -# {{{ array inteface +# {{{ array interface ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] IndexExpr = Union[IntegralT, "NormalizedSlice", "Array", None, EllipsisType] @@ -1682,9 +1682,9 @@ def _get_default_axes(ndim: int) -> AxesT: class _PytatoFrameSummary: """Class to store a single call frame.""" filename: str - lineno: int + lineno: Optional[int] name: str - line: str + line: Optional[str] def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: key_builder.rec(key_hash, @@ -1752,7 +1752,8 @@ def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) -def _get_default_tags(existing_tags: Optional[TagsType] = None) -> TagsType: +def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ + -> FrozenSet[Tag]: import traceback from pytato.tags import CreatedAt diff --git a/pytato/visualization.py b/pytato/visualization.py index d5693ed5a..92215c44d 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -31,7 +31,7 @@ import html from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List, - Mapping, Hashable, Any) + Mapping, Hashable, Any, FrozenSet) from pytools import UniqueNameGenerator from pytools.codegen import CodeGenerator as CodeGeneratorBase @@ -44,6 +44,7 @@ from pytato.codegen import normalize_outputs from pytato.transform import CachedMapper, ArrayOrNames +from pytools.tag import Tag from pytato.partition import GraphPartition from pytato.distributed import DistributedGraphPart @@ -72,7 +73,7 @@ class DotNodeInfo: edges: Dict[str, ArrayOrNames] -def stringify_created_at(tags: TagsType) -> str: +def stringify_created_at(tags: FrozenSet[Tag]) -> str: from pytato.tags 
import CreatedAt for tag in tags: if isinstance(tag, CreatedAt): @@ -81,7 +82,7 @@ def stringify_created_at(tags: TagsType) -> str: return "" -def stringify_tags(tags: TagsType) -> str: +def stringify_tags(tags: FrozenSet[Tag]) -> str: # The CreatedAt tag is handled in stringify_created_at() from pytato.tags import CreatedAt tags = frozenset(tag for tag in tags if not isinstance(tag, CreatedAt)) From 0794bdb4fd986549bce4ed6f463e31f34f995dec Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 13:51:00 -0500 Subject: [PATCH 031/178] add InfoTag class and filter tags based on it --- pytato/array.py | 7 ++++++- pytato/equality.py | 46 +++++++++++++++++++++++++++++++--------------- pytato/tags.py | 7 ++++++- 3 files changed, 43 insertions(+), 17 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 2ed902a46..a257da153 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -480,9 +480,12 @@ def T(self) -> Array: @memoize_method def __hash__(self) -> int: + from pytato.equality import preprocess_tags_for_equality attrs = [] for field in self._fields: attr = getattr(self, field) + if field == "tags": + attr = preprocess_tags_for_equality(attr) if isinstance(attr, dict): attr = frozenset(attr.items()) attrs.append(attr) @@ -1590,7 +1593,9 @@ def __init__(self, self._shape = shape def __hash__(self) -> int: - return id(self) + from pytato.equality import preprocess_tags_for_equality + return hash((self.name, id(self.data), self._shape, self.axes, + preprocess_tags_for_equality(self.tags), self.axes)) def __eq__(self, other: Any) -> bool: return self is other diff --git a/pytato/equality.py b/pytato/equality.py index 984d1f712..8f438b791 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -24,7 +24,7 @@ THE SOFTWARE. """ -from typing import Any, Callable, Dict, TYPE_CHECKING, Tuple, Union +from typing import Any, Callable, Dict, TYPE_CHECKING, Tuple, Union, FrozenSet from pytato.array import (AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, AxisPermutation, BasicIndex, Concatenate, DataWrapper, Einsum, @@ -32,11 +32,16 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) +from pytools.tag import Tag +from pytato.tags import InfoTag + if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult from pytato.distributed import DistributedRecv, DistributedSendRefHolder + __doc__ = """ +.. autofunction:: preprocess_tags_for_equality .. 
autoclass:: EqualityComparer """ @@ -44,6 +49,13 @@ ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] +def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: + """Remove tags of :class:`InfoTag` for equality comparison.""" + return frozenset(tag + for tag in tags + if not isinstance(tag, InfoTag)) + + # {{{ EqualityComparer class EqualityComparer: @@ -95,6 +107,10 @@ def handle_unsupported_array(self, expr1: Array, expr2: Any) -> bool: raise NotImplementedError(type(expr1).__name__) + def are_tags_equal(self, tags1: FrozenSet[Tag], tags2: FrozenSet[Tag]) -> bool: + return (preprocess_tags_for_equality(tags1) + == preprocess_tags_for_equality(tags2)) + def map_foreign(self, expr1: Any, expr2: Any) -> bool: raise NotImplementedError(type(expr1).__name__) @@ -103,14 +119,14 @@ def map_placeholder(self, expr1: Placeholder, expr2: Any) -> bool: and expr1.name == expr2.name and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.name == expr2.name - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -129,7 +145,7 @@ def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool: if isinstance(dim1, Array) else dim1 == dim2 for dim1, dim2 in zip(expr1.shape, expr2.shape)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes) def map_stack(self, expr1: Stack, expr2: Any) -> bool: @@ -138,7 +154,7 @@ def map_stack(self, expr1: Stack, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -148,7 +164,7 @@ def map_concatenate(self, expr1: Concatenate, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -156,7 +172,7 @@ def map_roll(self, expr1: Roll, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis == expr2.axis and self.rec(expr1.array, expr2.array) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -164,7 +180,7 @@ def map_axis_permutation(self, expr1: AxisPermutation, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis_permutation == expr2.axis_permutation and self.rec(expr1.array, expr2.array) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -177,7 +193,7 @@ def _map_index_base(self, expr1: IndexBase, expr2: Any) -> bool: and isinstance(idx2, Array)) else idx1 == idx2 for idx1, idx2 in zip(expr1.indices, expr2.indices)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -200,7 +216,7 @@ def map_reshape(self, expr1: Reshape, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.newshape == expr2.newshape and self.rec(expr1.array, expr2.array) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes 
) @@ -210,14 +226,14 @@ def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.args, expr2.args)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) def map_named_array(self, expr1: NamedArray, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -236,7 +252,7 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -254,7 +270,7 @@ def map_distributed_send_ref_holder( and expr1.send.dest_rank == expr2.send.dest_rank and expr1.send.comm_tag == expr2.send.comm_tag and expr1.send.tags == expr2.send.tags - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) ) def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: @@ -263,7 +279,7 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.comm_tag == expr2.comm_tag and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) ) # }}} diff --git a/pytato/tags.py b/pytato/tags.py index 36b28d588..b8e0ad7bd 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -105,11 +105,16 @@ class AssumeNonNegative(Tag): """ +class InfoTag(Tag): + """A type of tag whose value is purely informational and should not be used + for equality comparison.""" + + # See # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag): +class CreatedAt(UniqueTag, InfoTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From 43c83ec2dfc2007537af12227ca56e2f31be5e60 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 14:20:41 -0500 Subject: [PATCH 032/178] fix doc --- pytato/equality.py | 2 +- pytato/tags.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/equality.py b/pytato/equality.py index 8f438b791..0dacbbb9f 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -50,7 +50,7 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`InfoTag` for equality comparison.""" + """Remove tags of :class:`~pytato.tags.InfoTag` for equality comparison.""" return frozenset(tag for tag in tags if not isinstance(tag, InfoTag)) diff --git a/pytato/tags.py b/pytato/tags.py index b8e0ad7bd..0e480a273 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,6 +9,8 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative +.. autoclass:: InfoTag +.. 
autoclass:: CreatedAt """ From c09bbf39d3665122b31afe612f1cc13a62b683c3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 17:47:08 -0500 Subject: [PATCH 033/178] another doc fix --- pytato/array.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytato/array.py b/pytato/array.py index a257da153..c0285a909 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -142,6 +142,9 @@ .. autoclass:: ReductionAxis .. autoclass:: NormalizedSlice +.. autoclass:: _PytatoFrameSummary +.. autoclass:: _PytatoStackSummary + Internal stuff that is only here because the documentation tool wants it ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From cd67d684d1c777bc3efef080031446e033e75f42 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 21:20:21 -0500 Subject: [PATCH 034/178] use IgnoredForEqualityTag --- pytato/equality.py | 8 ++++---- pytato/tags.py | 10 ++-------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index 0dacbbb9f..df48c85be 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -32,8 +32,7 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) -from pytools.tag import Tag -from pytato.tags import InfoTag +from pytools.tag import Tag, IgnoredForEqualityTag if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -50,10 +49,11 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytato.tags.InfoTag` for equality comparison.""" + """Remove tags of :class:`~pytools.tag.IgnoredForEqualityTag` for equality + comparison.""" return frozenset(tag for tag in tags - if not isinstance(tag, InfoTag)) + if not isinstance(tag, IgnoredForEqualityTag)) # {{{ EqualityComparer diff --git a/pytato/tags.py b/pytato/tags.py index 0e480a273..a51ff14c3 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,13 +9,12 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative -.. autoclass:: InfoTag .. autoclass:: CreatedAt """ from dataclasses import dataclass -from pytools.tag import Tag, UniqueTag, tag_dataclass +from pytools.tag import Tag, UniqueTag, tag_dataclass, IgnoredForEqualityTag from pytato.array import _PytatoStackSummary @@ -107,16 +106,11 @@ class AssumeNonNegative(Tag): """ -class InfoTag(Tag): - """A type of tag whose value is purely informational and should not be used - for equality comparison.""" - - # See # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag, InfoTag): +class CreatedAt(UniqueTag, IgnoredForEqualityTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From 71dd791b45cee918d33ac8c4cb8e9a3cf643ec37 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 May 2022 10:34:51 -0500 Subject: [PATCH 035/178] UNDO BEFORE MERGE: use external project branches --- .github/workflows/ci.yml | 6 +++++- requirements.txt | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 81681f6da..032a42991 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -112,7 +112,11 @@ jobs: curl -L -O https://tiker.net/ci-support-v0 . 
./ci-support-v0 - test_downstream "$DOWNSTREAM_PROJECT" + if [[ "$DOWNSTREAM_PROJECT" != "meshmode" ]]; then + test_downstream "$DOWNSTREAM_PROJECT" + else + test_downstream https://github.com/inducer/meshmode@filter_tags + fi if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then python ../examples/simple-dg.py --lazy diff --git a/requirements.txt b/requirements.txt index a9cf2c76e..65fc57329 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1 +git+https://github.com/matthiasdiener/pytools.git@eq_tag#egg=pytools git+https://github.com/inducer/pymbolic.git#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy git+https://github.com/inducer/loopy.git#egg=loopy From b4f8b82d080433fcaa6e8affe8394b852ed17017 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 10:04:52 -0500 Subject: [PATCH 036/178] Revert "UNDO BEFORE MERGE: use external project branches" This reverts commit 71dd791b45cee918d33ac8c4cb8e9a3cf643ec37. --- .github/workflows/ci.yml | 6 +----- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 032a42991..81681f6da 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -112,11 +112,7 @@ jobs: curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 - if [[ "$DOWNSTREAM_PROJECT" != "meshmode" ]]; then - test_downstream "$DOWNSTREAM_PROJECT" - else - test_downstream https://github.com/inducer/meshmode@filter_tags - fi + test_downstream "$DOWNSTREAM_PROJECT" if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then python ../examples/simple-dg.py --lazy diff --git a/requirements.txt b/requirements.txt index 65fc57329..a9cf2c76e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/matthiasdiener/pytools.git@eq_tag#egg=pytools +git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1 git+https://github.com/inducer/pymbolic.git#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy git+https://github.com/inducer/loopy.git#egg=loopy From bfb22ba66d0872c3848a4ab41260f9f82ffa188d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 10:04:59 -0500 Subject: [PATCH 037/178] Revert "use IgnoredForEqualityTag" This reverts commit cd67d684d1c777bc3efef080031446e033e75f42. --- pytato/equality.py | 8 ++++---- pytato/tags.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index df48c85be..0dacbbb9f 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -32,7 +32,8 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) -from pytools.tag import Tag, IgnoredForEqualityTag +from pytools.tag import Tag +from pytato.tags import InfoTag if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -49,11 +50,10 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytools.tag.IgnoredForEqualityTag` for equality - comparison.""" + """Remove tags of :class:`~pytato.tags.InfoTag` for equality comparison.""" return frozenset(tag for tag in tags - if not isinstance(tag, IgnoredForEqualityTag)) + if not isinstance(tag, InfoTag)) # {{{ EqualityComparer diff --git a/pytato/tags.py b/pytato/tags.py index a51ff14c3..0e480a273 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,12 +9,13 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. 
autoclass:: AssumeNonNegative +.. autoclass:: InfoTag .. autoclass:: CreatedAt """ from dataclasses import dataclass -from pytools.tag import Tag, UniqueTag, tag_dataclass, IgnoredForEqualityTag +from pytools.tag import Tag, UniqueTag, tag_dataclass from pytato.array import _PytatoStackSummary @@ -106,11 +107,16 @@ class AssumeNonNegative(Tag): """ +class InfoTag(Tag): + """A type of tag whose value is purely informational and should not be used + for equality comparison.""" + + # See # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag, IgnoredForEqualityTag): +class CreatedAt(UniqueTag, InfoTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From 99ff0cdaeaa145b1475a05ffd24b4b1ba3a53162 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 11:14:29 -0500 Subject: [PATCH 038/178] rename InfoTag -> IgnoredForEqualityTag --- pytato/equality.py | 7 ++++--- pytato/tags.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index 0dacbbb9f..9a6145093 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -33,7 +33,7 @@ Array, DictOfNamedArrays, Placeholder, SizeParam) from pytools.tag import Tag -from pytato.tags import InfoTag +from pytato.tags import IgnoredForEqualityTag if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -50,10 +50,11 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytato.tags.InfoTag` for equality comparison.""" + """Remove tags of :class:`~pytato.tags.IgnoredForEqualityTag` for equality + comparison.""" return frozenset(tag for tag in tags - if not isinstance(tag, InfoTag)) + if not isinstance(tag, IgnoredForEqualityTag)) # {{{ EqualityComparer diff --git a/pytato/tags.py b/pytato/tags.py index 0e480a273..f6f9383d7 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,7 +9,7 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative -.. autoclass:: InfoTag +.. autoclass:: IgnoredForEqualityTag .. autoclass:: CreatedAt """ @@ -107,7 +107,7 @@ class AssumeNonNegative(Tag): """ -class InfoTag(Tag): +class IgnoredForEqualityTag(Tag): """A type of tag whose value is purely informational and should not be used for equality comparison.""" @@ -116,7 +116,7 @@ class InfoTag(Tag): # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag, InfoTag): +class CreatedAt(UniqueTag, IgnoredForEqualityTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From a818694d22faf97c68e599a2ccea851df4ff6151 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 11:36:53 -0500 Subject: [PATCH 039/178] more stringent tests --- test/test_pytato.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index f2ad7d9b4..cf1c05408 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -846,21 +846,32 @@ def test_created_at(): res = a+b res2 = a+b - # CreatedAt tags need to be filtered for equality to work correctly. 
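A standalone sketch of the semantics these tests pin down (illustrative stand-in classes, not pytato's actual tag hierarchy): two tag sets that differ only in an informational tag should compare equal once the informational tags are stripped.

    from dataclasses import dataclass

    class InfoTag:
        """Marker base: purely informational, ignored for equality."""

    @dataclass(frozen=True)
    class CreatedAt(InfoTag):
        traceback: str

    @dataclass(frozen=True)
    class Named:
        name: str

    def preprocess_tags_for_equality(tags):
        # Drop purely-informational tags before comparing.
        return frozenset(t for t in tags if not isinstance(t, InfoTag))

    t1 = frozenset({Named("x"), CreatedAt("foo.py:1")})
    t2 = frozenset({Named("x"), CreatedAt("bar.py:99")})

    assert t1 != t2                 # raw tag sets differ
    assert (preprocess_tags_for_equality(t1)
            == preprocess_tags_for_equality(t2))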
+ # {{{ Check that CreatedAt tags are filtered correctly for equality + from pytato.equality import preprocess_tags_for_equality + assert res == res2 + assert res.tags != res2.tags + assert (preprocess_tags_for_equality(res.tags) + == preprocess_tags_for_equality(res2.tags)) + + # }}} + from pytato.tags import CreatedAt created_tag = res.tags_of_type(CreatedAt) assert len(created_tag) == 1 + # {{{ Make sure the function name appears in the traceback + tag, = created_tag found = False - # Make sure the function name appears in the traceback - _unused = tag.traceback.to_stacksummary() # noqa + stacksummary = tag.traceback.to_stacksummary() + assert len(stacksummary) > 10 + for frame in tag.traceback.frames: if frame.name == "test_created_at": found = True @@ -868,6 +879,8 @@ def test_created_at(): assert found + # }}} + def test_pickling_and_unpickling_is_equal(): from testlib import RandomDAGContext, make_random_dag From 8a0a773124610a203bcec5bc50e21e959bea9333 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 11:46:43 -0500 Subject: [PATCH 040/178] undo unnecessary test changes --- pytato/visualization.py | 2 +- test/test_pytato.py | 49 +++++------------------------------------ 2 files changed, 7 insertions(+), 44 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index 92215c44d..b1535ada9 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -34,6 +34,7 @@ Mapping, Hashable, Any, FrozenSet) from pytools import UniqueNameGenerator +from pytools.tag import Tag from pytools.codegen import CodeGenerator as CodeGeneratorBase from pytato.loopy import LoopyCall @@ -44,7 +45,6 @@ from pytato.codegen import normalize_outputs from pytato.transform import CachedMapper, ArrayOrNames -from pytools.tag import Tag from pytato.partition import GraphPartition from pytato.distributed import DistributedGraphPart diff --git a/test/test_pytato.py b/test/test_pytato.py index cf1c05408..4ea7fb065 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -297,14 +297,6 @@ def test_dict_of_named_arrays_comparison(): dict2 = pt.make_dict_of_named_arrays({"out": 2 * x}) dict3 = pt.make_dict_of_named_arrays({"not_out": 2 * x}) dict4 = pt.make_dict_of_named_arrays({"out": 3 * x}) - - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - dict1 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict1)) - dict2 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict2)) - dict3 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict3)) - dict4 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict4)) - assert dict1 == dict2 assert dict1 != dict3 assert dict1 != dict4 @@ -413,13 +405,6 @@ def construct_intestine_graph(depth=90, seed=0): graph2 = construct_intestine_graph() graph3 = construct_intestine_graph(seed=3) - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - - graph1 = remove_tags_of_type(CreatedAt, graph1) - graph2 = remove_tags_of_type(CreatedAt, graph2) - graph3 = remove_tags_of_type(CreatedAt, graph3) - assert EqualityComparer()(graph1, graph2) assert EqualityComparer()(graph2, graph1) assert not EqualityComparer()(graph1, graph3) @@ -479,10 +464,9 @@ def test_array_dot_repr(): x = pt.make_placeholder("x", (10, 4), np.int64) y = pt.make_placeholder("y", (10, 4), np.int64) - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - def _assert_stripped_repr(ary: pt.Array, expected_repr: str): + from pytato.transform import remove_tags_of_type + from 
pytato.tags import CreatedAt ary = cast(pt.Array, remove_tags_of_type(CreatedAt, ary)) expected_str = "".join([c for c in repr(ary) if c not in [" ", "\n"]]) @@ -641,22 +625,10 @@ def test_rec_get_user_nodes(): expr = pt.make_dict_of_named_arrays({"out1": 2 * x1, "out2": 7 * x1 + 3 * x2}) - t1 = pt.transform.rec_get_user_nodes(expr, x1) - t1r = frozenset({2 * x1, 7*x1, 7*x1 + 3 * x2, expr}) - - t2 = pt.transform.rec_get_user_nodes(expr, x2) - t2r = frozenset({3 * x2, 7*x1 + 3 * x2, expr}) - - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - - t1 = frozenset({remove_tags_of_type(CreatedAt, t) for t in t1}) - t1r = frozenset({remove_tags_of_type(CreatedAt, t) for t in t1r}) - t2 = frozenset({remove_tags_of_type(CreatedAt, t) for t in t2}) - t2r = frozenset({remove_tags_of_type(CreatedAt, t) for t in t2r}) - - assert (t1 == t1r) - assert (t2 == t2r) + assert (pt.transform.rec_get_user_nodes(expr, x1) + == frozenset({2 * x1, 7*x1, 7*x1 + 3 * x2, expr})) + assert (pt.transform.rec_get_user_nodes(expr, x2) + == frozenset({3 * x2, 7*x1 + 3 * x2, expr})) def test_rec_get_user_nodes_linear_complexity(): @@ -703,13 +675,7 @@ def construct_intestine_graph(depth=90, seed=0): return y, x - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - expr, inp = construct_intestine_graph() - expr = remove_tags_of_type(CreatedAt, expr) - inp = remove_tags_of_type(CreatedAt, inp) - user_collector = pt.transform.UsersCollector() user_collector(expr) @@ -720,9 +686,6 @@ def post_visit(self, expr): expected_result[expr] = {"foo"} expr, inp = construct_intestine_graph() - expr = remove_tags_of_type(CreatedAt, expr) - inp = remove_tags_of_type(CreatedAt, inp) - result = pt.transform.tag_user_nodes(user_collector.node_to_users, "foo", inp) ExpectedResultComputer()(expr) From 91fe92f0f78d52e8cf61dbc4568d64f7463645e1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 May 2022 15:26:43 -0500 Subject: [PATCH 041/178] Revert "Revert "use IgnoredForEqualityTag"" This reverts commit bfb22ba66d0872c3848a4ab41260f9f82ffa188d. --- pytato/equality.py | 5 ++--- pytato/tags.py | 8 +------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index 9a6145093..df48c85be 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -32,8 +32,7 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) -from pytools.tag import Tag -from pytato.tags import IgnoredForEqualityTag +from pytools.tag import Tag, IgnoredForEqualityTag if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -50,7 +49,7 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytato.tags.IgnoredForEqualityTag` for equality + """Remove tags of :class:`~pytools.tag.IgnoredForEqualityTag` for equality comparison.""" return frozenset(tag for tag in tags diff --git a/pytato/tags.py b/pytato/tags.py index f6f9383d7..a51ff14c3 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,13 +9,12 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative -.. autoclass:: IgnoredForEqualityTag .. 
autoclass:: CreatedAt """ from dataclasses import dataclass -from pytools.tag import Tag, UniqueTag, tag_dataclass +from pytools.tag import Tag, UniqueTag, tag_dataclass, IgnoredForEqualityTag from pytato.array import _PytatoStackSummary @@ -107,11 +106,6 @@ class AssumeNonNegative(Tag): """ -class IgnoredForEqualityTag(Tag): - """A type of tag whose value is purely informational and should not be used - for equality comparison.""" - - # See # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. From 1111b7915e23561161eee132e8d5ccb716b101eb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 May 2022 16:21:12 -0500 Subject: [PATCH 042/178] simplify condition --- pytato/array.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index c0285a909..8d1d5f2ec 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1765,9 +1765,7 @@ def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ import traceback from pytato.tags import CreatedAt - if __debug__ and ( - existing_tags is None - or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): + if __debug__ and not any(isinstance(tag, CreatedAt) for tag in existing_tags): frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) for s in traceback.extract_stack()) c = CreatedAt(_PytatoStackSummary(frames)) From ff2582f0d2c37d070ecdbf771ea9df61bff6f310 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 May 2022 16:31:23 -0500 Subject: [PATCH 043/178] Revert "simplify condition" This reverts commit 1111b7915e23561161eee132e8d5ccb716b101eb. --- pytato/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 8d1d5f2ec..c0285a909 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1765,7 +1765,9 @@ def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ import traceback from pytato.tags import CreatedAt - if __debug__ and not any(isinstance(tag, CreatedAt) for tag in existing_tags): + if __debug__ and ( + existing_tags is None + or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) for s in traceback.extract_stack()) c = CreatedAt(_PytatoStackSummary(frames)) From 26c3590e07313788ad2a1b856c5f7e1474ea07db Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 21 May 2022 18:02:40 -0500 Subject: [PATCH 044/178] bump pytools version + a few spelling fixes --- pytato/array.py | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index c0285a909..d9d2beeeb 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -355,7 +355,7 @@ class Array(Taggable): :class:`~pytato.array.IndexLambda` is used to produce references to named arrays. Since any array that needs to be referenced in this way needs to obey this restriction anyway, - a decision was made to requir the same of *all* array expressions. + a decision was made to require the same of *all* array expressions. .. attribute:: dtype @@ -677,7 +677,7 @@ def dtype(self) -> np.dtype[Any]: class NamedArray(Array): """An entry in a :class:`AbstractResultWithNamedArrays`. Holds a reference - back to thecontaining instance as well as the name by which *self* is + back to the containing instance as well as the name by which *self* is known there. .. 
automethod:: __init__ diff --git a/setup.py b/setup.py index 336157d5d..e32d7eca0 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ python_requires="~=3.8", install_requires=[ "loopy>=2020.2", - "pytools>=2021.1", + "pytools>=2022.1.8", "pyrsistent" ], package_data={"pytato": ["py.typed"]}, From 423e3fb3aaa6950d458b360bf984f13e5bac8c1c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 21 May 2022 18:10:24 -0500 Subject: [PATCH 045/178] remove duplicated self.axes in hash() --- pytato/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index d9d2beeeb..58c6f8ac1 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1598,7 +1598,7 @@ def __init__(self, def __hash__(self) -> int: from pytato.equality import preprocess_tags_for_equality return hash((self.name, id(self.data), self._shape, self.axes, - preprocess_tags_for_equality(self.tags), self.axes)) + preprocess_tags_for_equality(self.tags))) def __eq__(self, other: Any) -> bool: return self is other From 87606ed9569d3d26a36a858316c23f2fcf174840 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 22:32:52 -0500 Subject: [PATCH 046/178] use Taggable{__eq__,__hash__} --- pytato/array.py | 3 +-- pytato/equality.py | 36 ++++++++++++++++-------------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 2d2839f82..157428b83 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1596,9 +1596,8 @@ def __init__(self, self._shape = shape def __hash__(self) -> int: - from pytato.equality import preprocess_tags_for_equality return hash((self.name, id(self.data), self._shape, self.axes, - preprocess_tags_for_equality(self.tags))) + Taggable.__hash__(self))) def __eq__(self, other: Any) -> bool: return self is other diff --git a/pytato/equality.py b/pytato/equality.py index 4dffab13c..4a99bf4a1 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -32,7 +32,7 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) -from pytools.tag import Tag, IgnoredForEqualityTag +from pytools.tag import Tag, IgnoredForEqualityTag, Taggable if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -107,10 +107,6 @@ def handle_unsupported_array(self, expr1: Array, expr2: Any) -> bool: raise NotImplementedError(type(expr1).__name__) - def are_tags_equal(self, tags1: FrozenSet[Tag], tags2: FrozenSet[Tag]) -> bool: - return (preprocess_tags_for_equality(tags1) - == preprocess_tags_for_equality(tags2)) - def map_foreign(self, expr1: Any, expr2: Any) -> bool: raise NotImplementedError(type(expr1).__name__) @@ -119,14 +115,14 @@ def map_placeholder(self, expr1: Placeholder, expr2: Any) -> bool: and expr1.name == expr2.name and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.name == expr2.name - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -145,7 +141,7 @@ def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool: if isinstance(dim1, Array) else dim1 == dim2 for dim1, dim2 in zip(expr1.shape, expr2.shape)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes) def map_stack(self, 
expr1: Stack, expr2: Any) -> bool: @@ -154,7 +150,7 @@ def map_stack(self, expr1: Stack, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -164,7 +160,7 @@ def map_concatenate(self, expr1: Concatenate, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -172,7 +168,7 @@ def map_roll(self, expr1: Roll, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis == expr2.axis and self.rec(expr1.array, expr2.array) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -180,7 +176,7 @@ def map_axis_permutation(self, expr1: AxisPermutation, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis_permutation == expr2.axis_permutation and self.rec(expr1.array, expr2.array) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -193,7 +189,7 @@ def _map_index_base(self, expr1: IndexBase, expr2: Any) -> bool: and isinstance(idx2, Array)) else idx1 == idx2 for idx1, idx2 in zip(expr1.indices, expr2.indices)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -216,7 +212,7 @@ def map_reshape(self, expr1: Reshape, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.newshape == expr2.newshape and self.rec(expr1.array, expr2.array) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -226,14 +222,14 @@ def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.args, expr2.args)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) def map_named_array(self, expr1: NamedArray, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -252,7 +248,7 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -269,8 +265,8 @@ def map_distributed_send_ref_holder( and self.rec(expr1.passthrough_data, expr2.passthrough_data) and expr1.send.dest_rank == expr2.send.dest_rank and expr1.send.comm_tag == expr2.send.comm_tag - and expr1.send.tags == expr2.send.tags - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1.send, expr2.send) + and Taggable.__eq__(expr1, expr2) ) def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: @@ -279,7 +275,7 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.comm_tag == expr2.comm_tag and expr1.shape == expr2.shape and 
expr1.dtype == expr2.dtype - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) ) # }}} From 73c7f7793e900d5a8ed8a4f03d0cf668be8c5067 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 4 Jun 2022 15:09:36 +0200 Subject: [PATCH 047/178] add another test --- test/test_pytato.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 67de9461e..961271d53 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -806,23 +806,34 @@ def test_created_at(): a = pt.make_placeholder("a", (10, 10), "float64") b = pt.make_placeholder("b", (10, 10), "float64") - res = a+b + # res1 and res2 are defined on different lines and should have different + # CreatedAt tags. + res1 = a+b res2 = a+b - # {{{ Check that CreatedAt tags are filtered correctly for equality + # res3 and res4 are defined on the same line and should have the same + # CreatedAt tags. + res3 = a+b; res4 = a+b # noqa: E702 + + # {{{ Check that CreatedAt tags are handled correctly for equality + from pytato.equality import preprocess_tags_for_equality - assert res == res2 + assert res1 == res2 == res3 == res4 + + assert res1.tags != res2.tags + assert res3.tags == res4.tags - assert res.tags != res2.tags - assert (preprocess_tags_for_equality(res.tags) + assert (preprocess_tags_for_equality(res1.tags) == preprocess_tags_for_equality(res2.tags)) + assert (preprocess_tags_for_equality(res3.tags) + == preprocess_tags_for_equality(res4.tags)) # }}} from pytato.tags import CreatedAt - created_tag = res.tags_of_type(CreatedAt) + created_tag = res1.tags_of_type(CreatedAt) assert len(created_tag) == 1 From b84b66ef83ac0d31244c6847279900306f977271 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 4 Jun 2022 15:30:13 +0200 Subject: [PATCH 048/178] add vis test --- test/test_pytato.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 961271d53..2a76f841d 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -847,7 +847,7 @@ def test_created_at(): assert len(stacksummary) > 10 for frame in tag.traceback.frames: - if frame.name == "test_created_at": + if frame.name == "test_created_at" and "a+b" in frame.line: found = True break @@ -855,6 +855,15 @@ def test_created_at(): # }}} + # {{{ Make sure that CreatedAt tags are in the visualization + + from pytato.visualization import get_dot_graph + s = get_dot_graph(res1) + assert "test_created_at" in s + assert "a+b" in s + + # }}} + def test_pickling_and_unpickling_is_equal(): from testlib import RandomDAGContext, make_random_dag From 6c653bf3ae968640135c07fddeb2756737b97953 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Jun 2022 10:32:52 -0500 Subject: [PATCH 049/178] make _PytatoFrameSummary, _PytatoStackSummary undocumented --- pytato/array.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 6b0ede462..fc1f9f94f 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -143,8 +143,13 @@ .. autoclass:: EinsumReductionAxis .. autoclass:: NormalizedSlice -.. autoclass:: _PytatoFrameSummary -.. autoclass:: _PytatoStackSummary +Internal classes for traceback +------------------------------ + +Please consider these undocumented and subject to change at any time. + +.. class:: _PytatoFrameSummary +.. 
class:: _PytatoStackSummary Internal stuff that is only here because the documentation tool wants it ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 02dd6f5b353494c507bed3492e9457a91d91323b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Jun 2022 10:38:41 -0500 Subject: [PATCH 050/178] use Taggable.__hash__ for tags in Array.__hash__ --- pytato/array.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index fc1f9f94f..0e3a3908e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -502,12 +502,11 @@ def T(self) -> Array: @memoize_method def __hash__(self) -> int: - from pytato.equality import preprocess_tags_for_equality attrs = [] for field in self._fields: attr = getattr(self, field) if field == "tags": - attr = preprocess_tags_for_equality(attr) + attr = Taggable.__hash__(self) if isinstance(attr, dict): attr = frozenset(attr.items()) attrs.append(attr) From 7c937079ec4498917b11e8a3b1cf47a120303376 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 28 Mar 2023 15:44:20 -0500 Subject: [PATCH 051/178] change dataclass to attrs --- pytato/array.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index d912ffc68..f6c2e5d69 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1808,7 +1808,7 @@ def _get_default_axes(ndim: int) -> AxesT: return tuple(Axis(frozenset()) for _ in range(ndim)) -@dataclass(frozen=True, eq=True) +@attrs.define(frozen=True, eq=True) class _PytatoFrameSummary: """Class to store a single call frame.""" filename: str @@ -1820,12 +1820,12 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: key_builder.rec(key_hash, (self.__class__.__module__, self.__class__.__qualname__)) - from dataclasses import fields + from attrs import fields # Fields are ordered consistently, so ordered hashing is OK. # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. - for f in fields(self): + for f in fields(self.__class__): key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: @@ -1840,7 +1840,7 @@ def __repr__(self) -> str: return f"{self.filename}:{self.lineno}, in {self.name}(): {self.line}" -@dataclass(frozen=True, eq=True) +@attrs.define(frozen=True, eq=True) class _PytatoStackSummary: """Class to store a list of :class:`_PytatoFrameSummary` call frames.""" frames: Tuple[_PytatoFrameSummary, ...] @@ -1849,19 +1849,18 @@ def to_stacksummary(self) -> StackSummary: frames = [FrameSummary(f.filename, f.lineno, f.name, line=f.line) for f in self.frames] - # type-ignore-reason: from_list also takes List[FrameSummary] - return StackSummary.from_list(frames) # type: ignore[arg-type] + return StackSummary.from_list(frames) def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: key_builder.rec(key_hash, (self.__class__.__module__, self.__class__.__qualname__)) - from dataclasses import fields + from attrs import fields # Fields are ordered consistently, so ordered hashing is OK. # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. 
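The hashing loop here relies on attrs.fields returning fields in a stable declaration order, including fields inherited from base classes. A small self-contained check with hypothetical classes:

    import attrs

    @attrs.define(frozen=True)
    class Base:
        filename: str

    @attrs.define(frozen=True)
    class Frame(Base):
        lineno: int
        name: str

    # Base-class fields come first, then subclass fields, in
    # declaration order:
    print([f.name for f in attrs.fields(Frame)])
    # -> ['filename', 'lineno', 'name']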
- for f in fields(self): + for f in fields(self.__class__): key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: From dd9916bdb5571c405c948a9f259368921a22d583 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 28 Mar 2023 16:13:31 -0500 Subject: [PATCH 052/178] flake8 --- pytato/visualization.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index 413b5495a..dd2448a51 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -122,11 +122,11 @@ def __init__(self) -> None: def get_common_dot_info(self, expr: Array) -> DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), - "shape": stringify_shape(expr.shape), - "dtype": str(expr.dtype), - "tags": stringify_tags(expr.tags), - "created_at": stringify_created_at(expr.tags), - } + "shape": stringify_shape(expr.shape), + "dtype": str(expr.dtype), + "tags": stringify_tags(expr.tags), + "created_at": stringify_created_at(expr.tags), + } edges: Dict[str, ArrayOrNames] = {} return DotNodeInfo(title, fields, edges) From 61c029cbdc184007efbb3282d2be5388e4441180 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 28 Mar 2023 18:38:53 -0500 Subject: [PATCH 053/178] Taggable.__eq__ --- pytato/equality.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index b381e96ec..eea88a45f 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -247,7 +247,7 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: if isinstance(bnd, Array) else bnd == expr2.bindings[name] for name, bnd in expr1.bindings.items()) - and expr1.tags == expr2.tags + and Taggable.__eq__(expr1, expr2) ) def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: @@ -262,7 +262,7 @@ def map_dict_of_named_arrays(self, expr1: DictOfNamedArrays, expr2: Any) -> bool and frozenset(expr1._data.keys()) == frozenset(expr2._data.keys()) and all(self.rec(expr1._data[name], expr2._data[name]) for name in expr1._data) - and expr1.tags == expr2.tags + and Taggable.__eq__(expr1, expr2) ) def map_distributed_send_ref_holder( From 3cf1559cdc16d0f4604ef4a3178f35882d3614fc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 29 Mar 2023 14:22:55 -0500 Subject: [PATCH 054/178] add Array.tagged() --- pytato/array.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index f6c2e5d69..4a0faf0d9 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -180,7 +180,7 @@ import pymbolic.primitives as prim from pymbolic import var from pytools import memoize_method -from pytools.tag import Tag, Taggable +from pytools.tag import Tag, Taggable, ToTagSetConvertible from pytato.scalar_expr import (ScalarType, SCALAR_CLASSES, ScalarExpression, IntegralT, @@ -689,6 +689,12 @@ def __repr__(self) -> str: from pytato.stringifier import Reprifier return Reprifier()(self) + def tagged(self, tags: ToTagSetConvertible) -> Array: + from pytato.equality import preprocess_tags_for_equality + from pytools.tag import normalize_tags + new_tags = preprocess_tags_for_equality(normalize_tags(tags)) + return super().tagged(new_tags) + # }}} From 44d1c34acd2795feac052c6c4f4cc7a3764f3394 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 19 May 2023 14:11:41 -0500 Subject: [PATCH 055/178] restrict to DEBUG_ENABLED --- pytato/array.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytato/array.py 
b/pytato/array.py index e02f5e1c1..8ea7f9270 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1893,7 +1893,11 @@ def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ import traceback from pytato.tags import CreatedAt - if __debug__ and ( + from pytato import DEBUG_ENABLED + + # This has a significant overhead, so only enable it when PYTATO_DEBUG is + # enabled. + if DEBUG_ENABLED and ( existing_tags is None or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) From a150c79e603a4a6bc9ac79bf26cea703fdcfb04f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 19 May 2023 15:26:49 -0500 Subject: [PATCH 056/178] force DEBUG_ENABLED for test --- test/test_pytato.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 80aaf2c9a..bb2a46549 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -836,6 +836,9 @@ def test_created_at(): a = pt.make_placeholder("a", (10, 10), "float64") b = pt.make_placeholder("b", (10, 10), "float64") + _prev_debug_enabled = pt.DEBUG_ENABLED + pt.DEBUG_ENABLED = True + # res1 and res2 are defined on different lines and should have different # CreatedAt tags. res1 = a+b @@ -892,6 +895,8 @@ def test_created_at(): assert "test_created_at" in s assert "a+b" in s + pt.DEBUG_ENABLED = _prev_debug_enabled + # }}} From 07690673d0b161ec9e94ed10cf99b912b2704817 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 17 Nov 2021 11:19:40 -0600 Subject: [PATCH 057/178] CHERRY-PICK: Preserve High-Level Info in the Pymbolic expressions --- pytato/array.py | 44 ++++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 964ff351f..69efdcb1f 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -585,25 +585,37 @@ def _unary_op(self, op: Any) -> Array: axes=_get_default_axes(self.ndim), var_to_reduction_descr=Map()) - __mul__ = partialmethod(_binary_op, operator.mul) - __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True) - - __add__ = partialmethod(_binary_op, operator.add) - __radd__ = partialmethod(_binary_op, operator.add, reverse=True) - - __sub__ = partialmethod(_binary_op, operator.sub) - __rsub__ = partialmethod(_binary_op, operator.sub, reverse=True) - - __floordiv__ = partialmethod(_binary_op, operator.floordiv) - __rfloordiv__ = partialmethod(_binary_op, operator.floordiv, reverse=True) - - __truediv__ = partialmethod(_binary_op, operator.truediv, + # NOTE: Initializing the expression to "prim.Product(expr1, expr2)" is + # essential as opposed to performing "expr1 * expr2". This is to account + # for pymbolic's implementation of the "*" operator which might not + # instantiate the node corresponding to the operation when one of + # the operands is the neutral element of the operation. + # + # For the same reason 'prim.(Sum|FloorDiv|Quotient)' is preferred over the + # python operators on the operands. 
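The behavior the NOTE refers to is easy to observe directly; in the sketch below, the exact repr output depends on the pymbolic version, but the folding of the neutral element is the point:

    import pymbolic.primitives as prim

    x = prim.Variable("x")

    # The overloaded "*" folds away the neutral element and simply
    # returns the operand, so no Product node is ever created:
    print(repr(x * 1))                  # Variable('x')

    # Explicit construction preserves the operation as a node:
    print(repr(prim.Product((x, 1))))   # Product((Variable('x'), 1))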
+ + __mul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r))) + __rmul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)), + reverse=True) + + __add__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r))) + __radd__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)), + reverse=True) + + __sub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r))) + __rsub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)), + reverse=True) + + __floordiv__ = partialmethod(_binary_op, prim.FloorDiv) + __rfloordiv__ = partialmethod(_binary_op, prim.FloorDiv, reverse=True) + + __truediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type) - __rtruediv__ = partialmethod(_binary_op, operator.truediv, + __rtruediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type, reverse=True) - __pow__ = partialmethod(_binary_op, operator.pow) - __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True) + __pow__ = partialmethod(_binary_op, prim.Power) + __rpow__ = partialmethod(_binary_op, prim.Power, reverse=True) __neg__ = partialmethod(_unary_op, operator.neg) From 8b3a13b3243055523ab0d3b31edb28b2ed9aeac0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 25 May 2022 21:07:29 -0500 Subject: [PATCH 058/178] [CHERRY-PICK]: Call BranchMorpher after dw deduplication --- pytato/transform/__init__.py | 37 +++++++++++++++++++++++++++++++----- test/test_codegen.py | 2 +- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index a75083f9a..cd2a39674 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -224,10 +224,9 @@ def rec(self, # type: ignore[override] # type-ignore-reason: CachedMapper.rec's return type is imprecise return super().rec(expr) # type: ignore[return-value] - # type-ignore-reason: specialized variant of super-class' rec method - def __call__(self, # type: ignore[override] - expr: CopyMapperResultT) -> CopyMapperResultT: - return self.rec(expr) + # type-ignore reason: incompatible type with Mapper.rec + def __call__(self, expr: MappedT) -> MappedT: # type: ignore[override] + return self.rec(expr) # type: ignore[no-any-return] def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] ) -> Tuple[IndexOrShapeExpr, ...]: @@ -1569,6 +1568,33 @@ def tag_user_nodes( # }}} +# {{{ BranchMorpher + +class BranchMorpher(CopyMapper): + """ + A mapper that replaces equal segments of graphs with identical objects. 
+ """ + def __init__(self) -> None: + super().__init__() + self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {} + + def cache_key(self, expr: CachedMapperT) -> Any: + return (id(expr), expr) + + # type-ignore reason: incompatible with Mapper.rec + def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] + rec_expr = super().rec(expr) + try: + # type-ignored because 'result_cache' maps to ArrayOrNames + return self.result_cache[rec_expr] # type: ignore[return-value] + except KeyError: + self.result_cache[rec_expr] = rec_expr + # type-ignored because of super-class' relaxed types + return rec_expr # type: ignore[no-any-return] + +# }}} + + # {{{ deduplicate_data_wrappers def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable: @@ -1658,8 +1684,9 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: len(data_wrapper_cache), data_wrappers_encountered - len(data_wrapper_cache)) - return array_or_names + return BranchMorpher()(array_or_names) # }}} + # vim: foldmethod=marker diff --git a/test/test_codegen.py b/test/test_codegen.py index 874906528..46e7a3186 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1556,7 +1556,7 @@ def test_zero_size_cl_array_dedup(ctx_factory): x4 = pt.make_data_wrapper(x_cl2) out = pt.make_dict_of_named_arrays({"out1": 2*x1, - "out2": 2*x2, + "out2": 3*x2, "out3": x3 + x4 }) From 8e20d15eafb9c60b991da82967ffadcca98861a0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 4 Aug 2023 13:18:12 -0500 Subject: [PATCH 059/178] Define __attrs_post_init__ only if __debug__, for all Array classes --- pytato/array.py | 18 ++++++++++-------- pytato/function.py | 11 ++++++----- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index bbf4ae739..067c1a08d 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -745,10 +745,11 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC): def _is_eq_valid(self) -> bool: return self.__class__.__eq__ is AbstractResultWithNamedArrays.__eq__ - def __attrs_post_init__(self) -> None: - # ensure that a developer does not uses dataclass' "__eq__" - # or "__hash__" implementation as they have exponential complexity. - assert self._is_eq_valid() + if __debug__: + def __attrs_post_init__(self) -> None: + # ensure that a developer does not uses dataclass' "__eq__" + # or "__hash__" implementation as they have exponential complexity. + assert self._is_eq_valid() @abstractmethod def __contains__(self, name: object) -> bool: @@ -1450,10 +1451,11 @@ class Reshape(IndexRemappingBase): _mapper_method: ClassVar[str] = "map_reshape" - def __attrs_post_init__(self) -> None: - # FIXME: Get rid of this restriction - assert self.order == "C" - super().__attrs_post_init__() + if __debug__: + def __attrs_post_init__(self) -> None: + # FIXME: Get rid of this restriction + assert self.order == "C" + super().__attrs_post_init__() @property def shape(self) -> ShapeType: diff --git a/pytato/function.py b/pytato/function.py index 6e5d044d2..b053831a0 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -276,11 +276,12 @@ class Call(AbstractResultWithNamedArrays): copy = attrs.evolve - def __attrs_post_init__(self) -> None: - # check that the invocation parameters and the function definition - # parameters agree with each other. 
- assert frozenset(self.bindings) == self.function.parameters - super().__attrs_post_init__() + if __debug__: + def __attrs_post_init__(self) -> None: + # check that the invocation parameters and the function definition + # parameters agree with each other. + assert frozenset(self.bindings) == self.function.parameters + super().__attrs_post_init__() def __contains__(self, name: object) -> bool: return name in self.function.returns From efcae65bda3679246d24f63ff6a9d8821a1c08df Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 10 Sep 2023 20:41:15 -0500 Subject: [PATCH 060/178] First shot at implementing 'F' ordered array reshapes --- pytato/array.py | 8 +-- pytato/transform/lower_to_index_lambda.py | 65 +++++++++++++++-------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 9fdf1de02..f09e2d8be 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1493,8 +1493,6 @@ class Reshape(IndexRemappingBase): _mapper_method: ClassVar[str] = "map_reshape" def __post_init__(self) -> None: - # FIXME: Get rid of this restriction - assert self.order == "C" super().__post_init__() __attrs_post_init__ = __post_init__ @@ -1958,8 +1956,7 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], """ :param array: array to be reshaped :param newshape: shape of the resulting array - :param order: ``"C"`` or ``"F"``. Layout order of the result array. Only - ``"C"`` allowed for now. + :param order: ``"C"`` or ``"F"``. Layout order of the resulting array. .. note:: @@ -1979,9 +1976,6 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], if not all(isinstance(axis_len, INT_CLASSES) for axis_len in array.shape): raise ValueError("reshape of arrays with symbolic lengths not allowed") - if order != "C": - raise NotImplementedError("Reshapes to a 'F'-ordered arrays") - newshape_explicit = [] for new_axislen in newshape: diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index a2bb443f0..5c2dfbca1 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -51,27 +51,50 @@ def _get_reshaped_indices(expr: Reshape) -> Tuple[ScalarExpression, ...]: assert expr.size == 1 return () - if expr.order != "C": - raise NotImplementedError(expr.order) - - newstrides: List[IntegralT] = [1] # reshaped array strides - for new_axis_len in reversed(expr.shape[1:]): - assert isinstance(new_axis_len, INT_CLASSES) - newstrides.insert(0, newstrides[0]*new_axis_len) - - flattened_idx = sum(prim.Variable(f"_{i}")*stride - for i, stride in enumerate(newstrides)) - - oldstrides: List[IntegralT] = [1] # input array strides - for axis_len in reversed(expr.array.shape[1:]): - assert isinstance(axis_len, INT_CLASSES) - oldstrides.insert(0, oldstrides[0]*axis_len) - - assert isinstance(expr.array.shape[-1], INT_CLASSES) - oldsizetills = [expr.array.shape[-1]] # input array size till for axes idx - for old_axis_len in reversed(expr.array.shape[:-1]): - assert isinstance(old_axis_len, INT_CLASSES) - oldsizetills.insert(0, oldsizetills[0]*old_axis_len) + if expr.order not in ["C", "F"]: + raise NotImplementedError("Order expected to be 'C' or 'F'", + f" found {expr.order}") + + if expr.order == "C": + newstrides: List[IntegralT] = [1] # reshaped array strides + for new_axis_len in reversed(expr.shape[1:]): + assert isinstance(new_axis_len, INT_CLASSES) + newstrides.insert(0, newstrides[0]*new_axis_len) + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride 
in enumerate(newstrides)) + + oldstrides: List[IntegralT] = [1] # input array strides + for axis_len in reversed(expr.array.shape[1:]): + assert isinstance(axis_len, INT_CLASSES) + oldstrides.insert(0, oldstrides[0]*axis_len) + + assert isinstance(expr.array.shape[-1], INT_CLASSES) + oldsizetills = [expr.array.shape[-1]] # input array size + # till for axes idx + for old_axis_len in reversed(expr.array.shape[:-1]): + assert isinstance(old_axis_len, INT_CLASSES) + oldsizetills.insert(0, oldsizetills[0]*old_axis_len) + + else: + newstrides: List[IntegralT] = [1] # reshaped array strides + for new_axis_len in expr.shape[:-1]: + assert isinstance(new_axis_len, INT_CLASSES) + newstrides.append(newstrides[-1]*new_axis_len) + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + oldstrides: List[IntegralT] = [1] # input array strides + for axis_len in expr.array.shape[:-1]: + assert isinstance(axis_len, INT_CLASSES) + oldstrides.append(oldstrides[-1]*axis_len) + + assert isinstance(expr.array.shape[0], INT_CLASSES) + oldsizetills = [expr.array.shape[0]] # input array size till for axes idx + for old_axis_len in expr.array.shape[1:]: + assert isinstance(old_axis_len, INT_CLASSES) + oldsizetills.append(oldsizetills[-1]*old_axis_len) return tuple(((flattened_idx % sizetill) // stride) for stride, sizetill in zip(oldstrides, oldsizetills)) From 35c6d1fe084b82ac1282b70c0eb9471aab15a460 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 10 Sep 2023 21:38:24 -0500 Subject: [PATCH 061/178] Remove restriction on reshape order --- pytato/array.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index a981ee20e..bbfdafda0 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1453,8 +1453,6 @@ class Reshape(IndexRemappingBase): if __debug__: def __attrs_post_init__(self) -> None: - # FIXME: Get rid of this restriction - assert self.order == "C" super().__attrs_post_init__() @property From 86233c647e5e2854b15b74a66a15aa7aa62f29ff Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 14 Oct 2023 15:43:55 -0500 Subject: [PATCH 062/178] work around mypy/attrs issue --- pytato/stringifier.py | 7 +++++-- pytato/visualization/dot.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 8aac8d340..269c9a546 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -95,7 +95,9 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string - fields = tuple(field.name for field in attrs.fields(type(expr))) + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + fields = tuple(field.name + for field in attrs.fields(type(expr))) # type: ignore[misc] if expr.ndim <= 1: # prettify: if ndim <=1 'expr.axes' would be trivial, @@ -153,7 +155,8 @@ def _get_field_val(field: str) -> str: return (f"{type(expr).__name__}(" + ", ".join(f"{field.name}={_get_field_val(field.name)}" - for field in attrs.fields(type(expr))) + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + for field in attrs.fields(type(expr))) # type: ignore[misc] + ")") def map_loopy_call(self, expr: LoopyCall, depth: int) -> str: diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index d84b3aaec..32a2ae5b0 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -184,7 +184,8 @@ def handle_unsupported_array(self, # type: ignore[override] # Default handler, 
does its best to guess how to handle fields. info = self.get_common_dot_info(expr) - for field in attrs.fields(type(expr)): + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + for field in attrs.fields(type(expr)): # type: ignore[misc] if field.name in info.fields: continue attr = getattr(expr, field.name) From 3b6fdade9d66dd53e3eb2d8b54ac607025df4645 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 14 Oct 2023 15:58:55 -0500 Subject: [PATCH 063/178] fix for fields --- pytato/array.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 0d1ff5b8a..b1a5ecd98 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -517,15 +517,15 @@ def T(self) -> Array: @memoize_method def __hash__(self) -> int: - attrs = [] - for field in self._fields: + attrs_filtered: List[Any] = [] + for field in attrs.fields(type(self)): # type: ignore[misc] attr = getattr(self, field) if field == "tags": attr = Taggable.__hash__(self) if isinstance(attr, dict): attr = frozenset(attr.items()) - attrs.append(attr) - return hash(tuple(attrs)) + attrs_filtered.append(attr) + return hash(tuple(attrs_filtered)) def __eq__(self, other: Any) -> bool: if self is other: @@ -1796,7 +1796,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. - for f in fields(self.__class__): + for f in fields(self.__class__): # type: ignore[misc] key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: @@ -1831,7 +1831,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. 
- for f in fields(self.__class__): + for f in fields(self.__class__): # type: ignore[misc] key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: From 8a390a583f8450a2a85889cd31016ac40fdde5c8 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Wed, 1 Nov 2023 16:59:38 -0500 Subject: [PATCH 064/178] Update comments a little --- pytato/transform/__init__.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 509718014..62c7a3d40 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -245,26 +245,14 @@ def rec(self, # type: ignore[override] expr: CopyMapperResultT) -> CopyMapperResultT: # type-ignore-reason: CachedMapper.rec's return type is imprecise return super().rec(expr) # type: ignore[return-value] - # DISABLED/REPLACED FROM MAIN - # # type-ignore-reason: specialized variant of super-class' rec method - # def rec(self, # type: ignore[override] - # expr: CopyMapperResultT) -> CopyMapperResultT: - # # type-ignore-reason: CachedMapper.rec's return type is imprecise - # return super().rec(expr) # type: ignore[return-value] - # ----- PREVIOUS CODE IN MAIN - # # type-ignore reason: incompatible type with Mapper.rec - # def __call__(self, expr: MappedT) -> MappedT: # type: ignore[override] - # return self.rec(expr) # type: ignore[no-any-return] - # --------------------------- - # ------- CURRENT CODE IN MAIN - # # type-ignore-reason: specialized variant of super-class' rec method - # def __call__(self, # type: ignore[override] - # expr: CopyMapperResultT) -> CopyMapperResultT: - # return self.rec(expr) - # ------------------------------------------------------ - # --------- CURRENT CODE IN CEESD - __call__ = rec - # ------------------------------- + + # REPLACED WITH NEW CODE FROM MAIN + # __call__ = rec + # ------------------------------- + # type-ignore-reason: specialized variant of super-class' rec method + def __call__(self, # type: ignore[override] + expr: CopyMapperResultT) -> CopyMapperResultT: + return self.rec(expr) def clone_for_callee(self: _SelfMapper) -> _SelfMapper: """ From 870849a106c4100e3f3b5a4929e3179cd5ece2de Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 9 Nov 2023 12:44:48 -0600 Subject: [PATCH 065/178] attempt to fix tag issue --- pytato/distributed/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 41ae3273c..95067ee66 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -106,7 +106,7 @@ def set_union( next_tag = base_tag assert isinstance(all_tags, frozenset) - for sym_tag in sorted(all_tags): + for sym_tag in sorted(all_tags, key=lambda tag: repr(tag)): sym_tag_to_int_tag[sym_tag] = next_tag next_tag += 1 From 060f864e2153c9c6d650ad919ff1440dc795e136 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Nov 2023 16:53:19 -0600 Subject: [PATCH 066/178] number_distributed_tags: non-set, non-sorted numbering --- pytato/distributed/__init__.py | 2 +- pytato/distributed/tags.py | 46 ++++++++++------------------------ test/test_distributed.py | 10 ++++++-- 3 files changed, 22 insertions(+), 36 deletions(-) diff --git a/pytato/distributed/__init__.py b/pytato/distributed/__init__.py index 4354b2f0f..ee0ff39de 100644 --- a/pytato/distributed/__init__.py +++ b/pytato/distributed/__init__.py @@ -23,7 +23,7 @@ .. class:: CommTagType A type representing a communication tag. 
Communication tags must be - hashable and totally ordered (and hence comparable). + hashable. .. class:: ShapeType diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 41ae3273c..4b97b2d50 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -31,7 +31,7 @@ """ -from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar +from typing import TYPE_CHECKING, Tuple, TypeVar from pytato.distributed.partition import DistributedGraphPartition @@ -62,53 +62,33 @@ def number_distributed_tags( This is a potentially heavyweight MPI-collective operation on *mpi_communicator*. - - .. note:: - - This function requires that symbolic tags are comparable. """ - tags = frozenset({ + from pytools import flatten + + tags = tuple([ recv.comm_tag for part in partition.parts.values() for recv in part.name_to_recv_node.values() - } | { + ] + [ send.comm_tag for part in partition.parts.values() for sends in part.name_to_send_nodes.values() - for send in sends}) - - from mpi4py import MPI - - def set_union( - set_a: FrozenSet[T], set_b: FrozenSet[T], - mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]: - assert mpi_data_type is None - assert isinstance(set_a, frozenset) - assert isinstance(set_b, frozenset) - - return set_a | set_b + for send in sends]) root_rank = 0 - set_union_mpi_op = MPI.Op.Create( - # type ignore reason: mpi4py misdeclares op functions as returning - # None. - set_union, # type: ignore[arg-type] - commute=True) - try: - all_tags = mpi_communicator.reduce( - tags, set_union_mpi_op, root=root_rank) - finally: - set_union_mpi_op.Free() + all_tags = mpi_communicator.gather(tags, root=root_rank) if mpi_communicator.rank == root_rank: sym_tag_to_int_tag = {} next_tag = base_tag - assert isinstance(all_tags, frozenset) + assert isinstance(all_tags, list) + assert len(all_tags) == mpi_communicator.size - for sym_tag in sorted(all_tags): - sym_tag_to_int_tag[sym_tag] = next_tag - next_tag += 1 + for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call] + if sym_tag not in sym_tag_to_int_tag: + sym_tag_to_int_tag[sym_tag] = next_tag + next_tag += 1 mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank) else: diff --git a/test/test_distributed.py b/test/test_distributed.py index f7a8e5b4c..c36f4caae 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -266,7 +266,7 @@ def _do_test_distributed_execution_random_dag(ctx_factory): ntests = 10 for i in range(ntests): seed = 120 + i - print(f"Step {i} {seed}") + print(f"Step {i} {seed=}") # {{{ compute value with communication @@ -278,7 +278,13 @@ def gen_comm(rdagc): nonlocal comm_tag comm_tag += 1 - tag = (comm_tag, _RandomDAGTag) # noqa: B023 + + if comm_tag % 5 == 1: + tag = (comm_tag, frozenset([_RandomDAGTag, _RandomDAGTag])) + elif comm_tag % 5 == 2: + tag = (comm_tag, (_RandomDAGTag,)) + else: + tag = (comm_tag, _RandomDAGTag) # noqa: B023 inner = make_random_dag(rdagc) return pt.staple_distributed_send( From 65d014297d01bb767a11555f32b361193b5bfd28 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Nov 2023 17:33:10 -0600 Subject: [PATCH 067/178] make the test a bit more difficult --- test/test_distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index c36f4caae..3a3e785fc 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -279,10 +279,10 @@ def gen_comm(rdagc): nonlocal comm_tag comm_tag += 1 - if comm_tag % 5 == 1: - tag = (comm_tag, 
frozenset([_RandomDAGTag, _RandomDAGTag])) + if comm_tag % 5 == 1 or 1: + tag = (comm_tag, frozenset([_RandomDAGTag, "a", comm_tag])) elif comm_tag % 5 == 2: - tag = (comm_tag, (_RandomDAGTag,)) + tag = (comm_tag, (_RandomDAGTag, "b")) else: tag = (comm_tag, _RandomDAGTag) # noqa: B023 From 3ebfcfd30f9a75e3d046e637712e53e835ea2719 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Nov 2023 11:51:43 -0600 Subject: [PATCH 068/178] undo mypy ignores --- pytato/array.py | 6 +++--- pytato/stringifier.py | 7 ++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index b1a5ecd98..4511b4f96 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -518,7 +518,7 @@ def T(self) -> Array: @memoize_method def __hash__(self) -> int: attrs_filtered: List[Any] = [] - for field in attrs.fields(type(self)): # type: ignore[misc] + for field in attrs.fields(type(self)): attr = getattr(self, field) if field == "tags": attr = Taggable.__hash__(self) @@ -1796,7 +1796,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. - for f in fields(self.__class__): # type: ignore[misc] + for f in fields(self.__class__): key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: @@ -1831,7 +1831,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. - for f in fields(self.__class__): # type: ignore[misc] + for f in fields(self.__class__): key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 269c9a546..b5172e768 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -95,9 +95,7 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string - # type-ignore-reason: https://github.com/python/mypy/issues/16254 - fields = tuple(field.name - for field in attrs.fields(type(expr))) # type: ignore[misc] + fields = tuple(field.name for field in attrs.fields(type(expr))) if expr.ndim <= 1: # prettify: if ndim <=1 'expr.axes' would be trivial, @@ -155,8 +153,7 @@ def _get_field_val(field: str) -> str: return (f"{type(expr).__name__}(" + ", ".join(f"{field.name}={_get_field_val(field.name)}" - # type-ignore-reason: https://github.com/python/mypy/issues/16254 - for field in attrs.fields(type(expr))) # type: ignore[misc] + for field in attrs.fields(type(expr))) + ")") def map_loopy_call(self, expr: LoopyCall, depth: int) -> str: From eb1c052539e90412d5fad022ee38895309e0ea08 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Nov 2023 17:29:46 -0600 Subject: [PATCH 069/178] rewrite to use a new field in Array, non_equality_tags --- pytato/array.py | 69 ++++++++++++++++++------------------ pytato/cmath.py | 3 +- pytato/equality.py | 14 ++------ pytato/stringifier.py | 2 ++ pytato/transform/__init__.py | 54 +++++++--------------------- pytato/utils.py | 10 ++++-- pytato/visualization/dot.py | 35 ++++++++++++++---- test/test_pytato.py | 21 ++++------- 8 files changed, 95 insertions(+), 113 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 4511b4f96..99f645d37 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -181,7 +181,7 @@ import pymbolic.primitives as prim from 
pymbolic import var from pytools import memoize_method -from pytools.tag import Tag, Taggable, ToTagSetConvertible +from pytools.tag import Tag, Taggable from pytato.scalar_expr import (ScalarType, SCALAR_CLASSES, ScalarExpression, IntegralT, @@ -448,6 +448,10 @@ class Array(Taggable): axes: AxesT = attrs.field(kw_only=True) tags: FrozenSet[Tag] = attrs.field(kw_only=True) + # These are automatically excluded from equality in EqualityComparer + non_equality_tags: FrozenSet[Tag] = attrs.field(kw_only=True, hash=False, + default=None) + _mapper_method: ClassVar[str] # disallow numpy arithmetic from taking precedence @@ -515,18 +519,6 @@ def T(self) -> Array: tags=_get_default_tags(), axes=_get_default_axes(self.ndim)) - @memoize_method - def __hash__(self) -> int: - attrs_filtered: List[Any] = [] - for field in attrs.fields(type(self)): - attr = getattr(self, field) - if field == "tags": - attr = Taggable.__hash__(self) - if isinstance(attr, dict): - attr = frozenset(attr.items()) - attrs_filtered.append(attr) - return hash(tuple(attrs_filtered)) - def __eq__(self, other: Any) -> bool: if self is other: return True @@ -681,12 +673,6 @@ def __repr__(self) -> str: from pytato.stringifier import Reprifier return Reprifier()(self) - def tagged(self, tags: ToTagSetConvertible) -> Array: - from pytato.equality import preprocess_tags_for_equality - from pytools.tag import normalize_tags - new_tags = preprocess_tags_for_equality(normalize_tags(tags)) - return super().tagged(new_tags) - # }}} @@ -1852,24 +1838,21 @@ def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) -def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ - -> FrozenSet[Tag]: +def _get_created_at_tag() -> Optional[Tag]: import traceback from pytato.tags import CreatedAt - from pytato import DEBUG_ENABLED + if not __debug__: + return None - # This has a significant overhead, so only enable it when PYTATO_DEBUG is - # enabled. 
- if DEBUG_ENABLED and ( - existing_tags is None - or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): - frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) + frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) for s in traceback.extract_stack()) - c = CreatedAt(_PytatoStackSummary(frames)) - return frozenset((c,)) - else: - return frozenset() + + return CreatedAt(_PytatoStackSummary(frames)) + + +def _get_default_tags() -> FrozenSet[Tag]: + return frozenset() def matmul(x1: Array, x2: Array) -> Array: @@ -1931,6 +1914,7 @@ def roll(a: Array, shift: int, axis: Optional[int] = None) -> Array: return Roll(a, shift, axis, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(a.ndim)) @@ -1953,6 +1937,7 @@ def transpose(a: Array, axes: Optional[Sequence[int]] = None) -> Array: return AxisPermutation(a, tuple(axes), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(a.ndim)) @@ -1987,6 +1972,7 @@ def stack(arrays: Sequence[Array], axis: int = 0) -> Array: return Stack(tuple(arrays), axis, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(arrays[0].ndim+1)) @@ -2022,6 +2008,7 @@ def shape_except_axis(ary: Array) -> ShapeType: return Concatenate(tuple(arrays), axis, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(arrays[0].ndim)) @@ -2085,6 +2072,7 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], return Reshape(array, tuple(newshape_explicit), order, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(newshape_explicit))) @@ -2128,7 +2116,8 @@ def make_placeholder(name: str, f" expected {len(shape)}, got {len(axes)}.") return Placeholder(name=name, shape=shape, dtype=dtype, axes=axes, - tags=(tags | _get_default_tags(tags))) + tags=(tags | _get_default_tags()), + non_equality_tags=frozenset({_get_created_at_tag()}),) def make_size_param(name: str, @@ -2142,7 +2131,8 @@ def make_size_param(name: str, :param tags: implementation tags """ _check_identifier(name, optional=False) - return SizeParam(name, tags=(tags | _get_default_tags(tags))) + return SizeParam(name, tags=(tags | _get_default_tags()), + non_equality_tags=frozenset({_get_created_at_tag()}),) def make_data_wrapper(data: DataInterface, @@ -2181,7 +2171,8 @@ def make_data_wrapper(data: DataInterface, raise ValueError("'axes' dimensionality mismatch:" f" expected {len(shape)}, got {len(axes)}.") - return DataWrapper(data, shape, axes=axes, tags=(tags | _get_default_tags(tags))) + return DataWrapper(data, shape, axes=axes, tags=(tags | _get_default_tags()), + non_equality_tags=frozenset({_get_created_at_tag()}),) # }}} @@ -2212,6 +2203,7 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType, return IndexLambda(expr=fill_value, shape=shape, dtype=dtype, bindings=immutabledict(), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict()) @@ -2258,6 +2250,7 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803 return IndexLambda(expr=parse(f"1 if ((_1 - _0) == {k}) else 0"), shape=(N, M), dtype=dtype, bindings=immutabledict({}), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(2), var_to_reduction_descr=immutabledict()) @@ 
-2353,6 +2346,7 @@ def arange(*args: Any, **kwargs: Any) -> Array: return IndexLambda(expr=start + Variable("_0") * step, shape=(size,), dtype=dtype, bindings=immutabledict(), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(1), var_to_reduction_descr=immutabledict()) @@ -2464,6 +2458,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]: dtype=np.dtype(np.bool_), bindings={"_in0": x}, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(x.shape)), var_to_reduction_descr=immutabledict()) @@ -2520,6 +2515,7 @@ def where(condition: ArrayOrScalar, dtype=dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(result_shape)), var_to_reduction_descr=immutabledict()) @@ -2618,6 +2614,7 @@ def make_index_lambda( shape=shape, dtype=dtype, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict (processed_var_to_reduction_descr)) @@ -2703,6 +2700,7 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: dtype=array.dtype, bindings=immutabledict({"in": array}), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict()) @@ -2777,6 +2775,7 @@ def expand_dims(array: Array, axis: Union[Tuple[int, ...], int]) -> Array: return Reshape(array=array, newshape=tuple(new_shape), order="C", tags=(_get_default_tags() | {ExpandedDimsReshape(tuple(normalized_axis))}), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(new_shape))) # }}} diff --git a/pytato/cmath.py b/pytato/cmath.py index 38c520c7e..9f8e8c7fa 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -59,7 +59,7 @@ import pymbolic.primitives as prim from typing import Tuple, Optional from pytato.array import (Array, ArrayOrScalar, IndexLambda, _dtype_any, - _get_default_axes, _get_default_tags) + _get_default_axes, _get_default_tags, _get_created_at_tag) from pytato.scalar_expr import SCALAR_CLASSES from pymbolic import var from immutabledict import immutabledict @@ -115,6 +115,7 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], tuple(sym_args)), shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict(), ) diff --git a/pytato/equality.py b/pytato/equality.py index 79b7427c4..76c21b4ed 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -24,7 +24,7 @@ THE SOFTWARE. """ -from typing import Any, Callable, Dict, TYPE_CHECKING, Tuple, Union, FrozenSet +from typing import Any, Callable, Dict, TYPE_CHECKING, Tuple, Union from pytato.array import (AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, AxisPermutation, BasicIndex, Concatenate, DataWrapper, Einsum, @@ -34,7 +34,7 @@ from pytato.function import Call, NamedCallResult, FunctionDefinition from pytools import memoize_method -from pytools.tag import Tag, IgnoredForEqualityTag, Taggable +from pytools.tag import Taggable if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -42,22 +42,12 @@ __doc__ = """ -.. autofunction:: preprocess_tags_for_equality .. 
autoclass:: EqualityComparer """ - ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] -def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytools.tag.IgnoredForEqualityTag` for equality - comparison.""" - return frozenset(tag - for tag in tags - if not isinstance(tag, IgnoredForEqualityTag)) - - # {{{ EqualityComparer class EqualityComparer: diff --git a/pytato/stringifier.py b/pytato/stringifier.py index b5172e768..9afb887c0 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -97,6 +97,8 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: fields = tuple(field.name for field in attrs.fields(type(expr))) + fields = tuple(field for field in fields if field != "non_equality_tags") + if expr.ndim <= 1: # prettify: if ndim <=1 'expr.axes' would be trivial, # => don't print. diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 40b416750..9cd1ed16a 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -33,7 +33,7 @@ from immutabledict import immutabledict from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING, - Hashable) + Hashable, cast) from pytato.array import ( Array, IndexLambda, Placeholder, Stack, Roll, @@ -82,7 +82,6 @@ .. autofunction:: copy_dict_of_named_arrays .. autofunction:: get_dependencies .. autofunction:: map_and_copy -.. autofunction:: remove_tags_of_type .. autofunction:: materialize_with_mpms .. autofunction:: deduplicate_data_wrappers .. automodule:: pytato.transform.lower_to_index_lambda @@ -207,8 +206,7 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> Hashable: return expr - # type-ignore-reason: incompatible with super class - def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type: ignore[override] + def rec(self, expr: ArrayOrNames) -> CachedMapperT: key = self.get_cache_key(expr) try: return self._cache[key] @@ -219,9 +217,7 @@ def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type: ignore[override] return result # type: ignore[no-any-return] if TYPE_CHECKING: - # type-ignore-reason: incompatible with super class - def __call__(self, expr: ArrayOrNames # type: ignore[override] - ) -> CachedMapperT: + def __call__(self, expr: ArrayOrNames) -> CachedMapperT: return self.rec(expr) # }}} @@ -241,15 +237,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]): This does not copy the data of a :class:`pytato.array.DataWrapper`. """ if TYPE_CHECKING: - # type-ignore-reason: specialized variant of super-class' rec method - def rec(self, # type: ignore[override] - expr: CopyMapperResultT) -> CopyMapperResultT: - # type-ignore-reason: CachedMapper.rec's return type is imprecise - return super().rec(expr) # type: ignore[return-value] - - # type-ignore-reason: specialized variant of super-class' rec method - def __call__(self, # type: ignore[override] - expr: CopyMapperResultT) -> CopyMapperResultT: + def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT: + return cast(CopyMapperResultT, super().rec(expr)) + + def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) def clone_for_callee(self: _SelfMapper) -> _SelfMapper: @@ -1193,17 +1184,13 @@ def __init__(self) -> None: super().__init__() self.topological_order: List[Array] = [] - # type-ignore-reason: dropped the extra `*args, **kwargs`. 
- def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def post_visit(self, expr: Any) -> None: # type: ignore[override] + def post_visit(self, expr: Any) -> None: self.topological_order.append(expr) - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def map_function_definition(self, # type: ignore[override] - expr: FunctionDefinition) -> None: + def map_function_definition(self, expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. return @@ -1227,8 +1214,7 @@ def clone_for_callee(self: _SelfMapper) -> _SelfMapper: # than Mapper.__init__ and does not have map_fn return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined] - # type-ignore-reason:incompatible with Mapper.rec() - def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] + def rec(self, expr: MappedT) -> MappedT: if expr in self._cache: # type-ignore-reason: parametric Mapping types aren't a thing return self._cache[expr] # type: ignore[return-value] @@ -1239,8 +1225,7 @@ def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] return result # type: ignore[return-value] if TYPE_CHECKING: - # type-ignore-reason: Mapper.__call__ returns Any - def __call__(self, expr: MappedT) -> MappedT: # type: ignore[override] + def __call__(self, expr: MappedT) -> MappedT: return self.rec(expr) # }}} @@ -1532,21 +1517,6 @@ def map_and_copy(expr: MappedT, return CachedMapAndCopyMapper(map_fn)(expr) -def remove_tags_of_type(tag_types: Union[type, Tuple[type]], expr: ArrayOrNames - ) -> ArrayOrNames: - def process_node(expr: ArrayOrNames) -> ArrayOrNames: - if isinstance(expr, Array): - return expr.copy(tags=frozenset({ - tag for tag in expr.tags - if not isinstance(tag, tag_types)})) - elif isinstance(expr, AbstractResultWithNamedArrays): - return expr - else: - raise AssertionError(type(expr)) - - return map_and_copy(expr, process_node) - - def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: r""" Materialize nodes in *expr* with MPMS materialization strategy. 
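
For orientation, a minimal sketch of the semantics this patch is after
(assuming an interpreter where __debug__ is true, so that CreatedAt tags are
recorded; the placeholder names are arbitrary):

    import pytato as pt
    from pytato.tags import CreatedAt

    a = pt.make_placeholder("a", (10, 10), "float64")
    b = pt.make_placeholder("b", (10, 10), "float64")

    res1 = a + b
    res2 = a + b   # same expression, different source line

    # Equality and .tags are blind to the creation traceback ...
    assert res1 == res2
    assert res1.tags == res2.tags

    # ... while the traceback itself stays retrievable:
    tracebacks = {tag for tag in res1.non_equality_tags
                  if isinstance(tag, CreatedAt)}

This mirrors the updated test_created_at further down: only non_equality_tags
differs between res1 and res2.
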
diff --git a/pytato/utils.py b/pytato/utils.py index 3937711f6..212197a93 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -179,7 +179,8 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 ) -> ArrayOrScalar: - from pytato.array import _get_default_axes, _get_default_tags + from pytato.array import (_get_default_axes, _get_default_tags, + _get_created_at_tag) if isinstance(a1, SCALAR_CLASSES): a1 = np.dtype(type(a1)).type(a1) @@ -207,6 +208,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, dtype=result_dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), var_to_reduction_descr=immutabledict(), axes=_get_default_axes(len(result_shape))) @@ -475,7 +477,8 @@ def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent: def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Array: from pytato.diagnostic import CannotBroadcastError - from pytato.array import _get_default_axes, _get_default_tags + from pytato.array import (_get_default_axes, _get_default_tags, + _get_created_at_tag) # {{{ handle ellipsis @@ -562,6 +565,7 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra ary, tuple(normalized_indices), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: @@ -569,6 +573,7 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra ary, tuple(normalized_indices), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: @@ -576,6 +581,7 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra return BasicIndex(ary, tuple(normalized_indices), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes( len([idx for idx in normalized_indices diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 32a2ae5b0..2cddaa2c2 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -171,9 +171,11 @@ def __init__(self) -> None: def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), - "shape": stringify_shape(expr.shape), - "dtype": str(expr.dtype), - "tags": stringify_tags(expr.tags)} + "shape": stringify_shape(expr.shape), + "dtype": str(expr.dtype), + "tags": stringify_tags(expr.tags), + "non_equality_tags": expr.non_equality_tags, + } edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} return _DotNodeInfo(title, fields, edges) @@ -188,6 +190,7 @@ def handle_unsupported_array(self, # type: ignore[override] for field in attrs.fields(type(expr)): # type: ignore[misc] if field.name in info.fields: continue + attr = getattr(expr, field.name) if isinstance(attr, Array): @@ -356,15 +359,31 @@ def dot_escape(s: str) -> str: return html.escape(s.replace("\\", "\\\\").replace(" ", "_")) +def dot_escape_leave_space(s: str) -> str: + # "\" and HTML are significant in graphviz. 
+    return html.escape(s.replace("\\", "\\\\"))
+
+
 # {{{ emit helpers


+def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str:
+    from pytato.tags import CreatedAt
+    for tag in non_equality_tags:
+        if isinstance(tag, CreatedAt):
+            return tag.traceback.short_str()
+
+    return ""
+
+
 def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str],
                 dot_node_id: str, color: str = "white") -> None:
     td_attrib = 'border="0"'
     table_attrib = 'border="0" cellborder="1" cellspacing="0"'

-    rows = ['<tr><td colspan="2" %s>%s</td></tr>'
-            % (td_attrib, dot_escape(title))]
+    rows = [f'<tr><td colspan="2" {td_attrib}>{dot_escape(title)}</td></tr>']
+
+    non_equality_tags = fields.pop("non_equality_tags", frozenset())
+    tooltip = dot_escape_leave_space(_stringify_created_at(non_equality_tags))

     for name, field in fields.items():
         field_content = dot_escape(field).replace("\n", "<br/>")
@@ -372,8 +391,10 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str],
             f"<tr><td {td_attrib}>{dot_escape(name)}:</td><td {td_attrib}>"
             f"{field_content}</td></tr>"
         )
-
-    table = "<table %s>\n%s</table>" % (table_attrib, "".join(rows))
-    emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color))
+
+    table = f"<table {table_attrib}>\n{''.join(rows)}</table>
" + emit(f"{dot_node_id} [label=<{table}> style=filled fillcolor={color} " + f'tooltip="{tooltip}"]') def _emit_name_cluster( diff --git a/test/test_pytato.py b/test/test_pytato.py index 0583bdce0..ad865c4ef 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -807,9 +807,6 @@ def test_created_at(): a = pt.make_placeholder("a", (10, 10), "float64") b = pt.make_placeholder("b", (10, 10), "float64") - _prev_debug_enabled = pt.DEBUG_ENABLED - pt.DEBUG_ENABLED = True - # res1 and res2 are defined on different lines and should have different # CreatedAt tags. res1 = a+b @@ -821,23 +818,21 @@ def test_created_at(): # {{{ Check that CreatedAt tags are handled correctly for equality - from pytato.equality import preprocess_tags_for_equality - assert res1 == res2 == res3 == res4 - assert res1.tags != res2.tags - assert res3.tags == res4.tags + assert res1.non_equality_tags != res2.non_equality_tags + assert res3.non_equality_tags == res4.non_equality_tags - assert (preprocess_tags_for_equality(res1.tags) - == preprocess_tags_for_equality(res2.tags)) - assert (preprocess_tags_for_equality(res3.tags) - == preprocess_tags_for_equality(res4.tags)) + assert res1.tags == res2.tags + assert res3.tags == res4.tags # }}} from pytato.tags import CreatedAt - created_tag = res1.tags_of_type(CreatedAt) + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) assert len(created_tag) == 1 @@ -866,8 +861,6 @@ def test_created_at(): assert "test_created_at" in s assert "a+b" in s - pt.DEBUG_ENABLED = _prev_debug_enabled - # }}} From a5cec505af2390c0d35a789d04edeae10ba4e8a6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Nov 2023 17:52:00 -0600 Subject: [PATCH 070/178] misc fixes --- pytato/array.py | 2 +- pytato/utils.py | 4 +--- pytato/visualization/dot.py | 2 +- test/test_pytato.py | 18 ++++++------------ 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 99f645d37..bdb4714db 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -450,7 +450,7 @@ class Array(Taggable): # These are automatically excluded from equality in EqualityComparer non_equality_tags: FrozenSet[Tag] = attrs.field(kw_only=True, hash=False, - default=None) + default=frozenset()) _mapper_method: ClassVar[str] diff --git a/pytato/utils.py b/pytato/utils.py index 212197a93..ad2b77377 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -253,8 +253,6 @@ def dim_to_index_lambda_components(expr: ShapeComponent, .. 
testsetup:: >>> import pytato as pt - >>> from pytato.transform import remove_tags_of_type - >>> from pytato.tags import CreatedAt >>> from pytato.utils import dim_to_index_lambda_components >>> from pytools import UniqueNameGenerator @@ -264,7 +262,7 @@ def dim_to_index_lambda_components(expr: ShapeComponent, >>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator()) >>> print(expr) 3*_in + 8 - >>> {"_in": remove_tags_of_type(CreatedAt, bnds["_in"])} + >>> bnds {'_in': SizeParam(name='n')} """ if isinstance(expr, INT_CLASSES): diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 2cddaa2c2..3984d986e 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -366,7 +366,7 @@ def dot_escape_leave_space(s: str) -> str: # {{{ emit helpers -def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str: +def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str: from pytato.tags import CreatedAt for tag in non_equality_tags: if isinstance(tag, CreatedAt): diff --git a/test/test_pytato.py b/test/test_pytato.py index ad865c4ef..8976d76d4 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -27,8 +27,6 @@ import sys -from typing import cast - import numpy as np import pytest import attrs @@ -451,16 +449,12 @@ def test_array_dot_repr(): x = pt.make_placeholder("x", (10, 4), np.int64) y = pt.make_placeholder("y", (10, 4), np.int64) - def _assert_stripped_repr(ary: pt.Array, expected_repr: str): - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - ary = cast(pt.Array, remove_tags_of_type(CreatedAt, ary)) - + def _assert_repr(ary: pt.Array, expected_repr: str): expected_str = "".join([c for c in expected_repr if c not in [" ", "\n"]]) result_str = "".join([c for c in repr(ary)if c not in [" ", "\n"]]) assert expected_str == result_str - _assert_stripped_repr( + _assert_repr( 3*x + 4*y, """ IndexLambda( @@ -489,7 +483,7 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): dtype='int64', name='y')})})""") - _assert_stripped_repr( + _assert_repr( pt.roll(x.reshape(2, 20).reshape(-1), 3), """ Roll( @@ -501,7 +495,7 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): newshape=(40), order='C'), shift=3, axis=0)""") - _assert_stripped_repr(y * pt.not_equal(x, 3), + _assert_repr(y * pt.not_equal(x, 3), """ IndexLambda( shape=(10, 4), @@ -521,7 +515,7 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='x')})})""") - _assert_stripped_repr( + _assert_repr( x[y[:, 2:3], x[2, :]], """ AdvancedIndexInContiguousAxes( @@ -536,7 +530,7 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): name='x'), indices=(2, NormalizedSlice(start=0, stop=4, step=1)))))""") - _assert_stripped_repr( + _assert_repr( pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]), """ Stack( From d9898c9103314afbbf9117c920fc50195d20bc4f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Nov 2023 18:20:41 -0600 Subject: [PATCH 071/178] undo some unecessary changes --- pytato/equality.py | 38 +++++++++++++++++------------------- pytato/stringifier.py | 7 +++++-- pytato/transform/__init__.py | 38 ++++++++++++++++++++++++------------ test/test_pytato.py | 4 ++-- 4 files changed, 51 insertions(+), 36 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index 76c21b4ed..42c2978cd 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -34,17 +34,15 @@ from pytato.function import Call, 
NamedCallResult, FunctionDefinition from pytools import memoize_method -from pytools.tag import Taggable - if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder - __doc__ = """ .. autoclass:: EqualityComparer """ + ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] @@ -107,14 +105,14 @@ def map_placeholder(self, expr1: Placeholder, expr2: Any) -> bool: and expr1.name == expr2.name and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.name == expr2.name - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -133,7 +131,7 @@ def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool: if isinstance(dim1, Array) else dim1 == dim2 for dim1, dim2 in zip(expr1.shape, expr2.shape)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes and expr1.var_to_reduction_descr == expr2.var_to_reduction_descr ) @@ -144,7 +142,7 @@ def map_stack(self, expr1: Stack, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -154,7 +152,7 @@ def map_concatenate(self, expr1: Concatenate, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -163,7 +161,7 @@ def map_roll(self, expr1: Roll, expr2: Any) -> bool: and expr1.axis == expr2.axis and expr1.shift == expr2.shift and self.rec(expr1.array, expr2.array) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -171,7 +169,7 @@ def map_axis_permutation(self, expr1: AxisPermutation, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis_permutation == expr2.axis_permutation and self.rec(expr1.array, expr2.array) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -184,7 +182,7 @@ def _map_index_base(self, expr1: IndexBase, expr2: Any) -> bool: and isinstance(idx2, Array)) else idx1 == idx2 for idx1, idx2 in zip(expr1.indices, expr2.indices)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -207,7 +205,7 @@ def map_reshape(self, expr1: Reshape, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.newshape == expr2.newshape and self.rec(expr1.array, expr2.array) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -217,7 +215,7 @@ def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.args, expr2.args)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes and expr1.redn_axis_to_redn_descr == expr2.redn_axis_to_redn_descr ) @@ -225,7 +223,7 @@ def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: def map_named_array(self, expr1: NamedArray, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, 
expr2._container) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -239,13 +237,13 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: if isinstance(bnd, Array) else bnd == expr2.bindings[name] for name, bnd in expr1.bindings.items()) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags ) def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -254,7 +252,7 @@ def map_dict_of_named_arrays(self, expr1: DictOfNamedArrays, expr2: Any) -> bool and frozenset(expr1._data.keys()) == frozenset(expr2._data.keys()) and all(self.rec(expr1._data[name], expr2._data[name]) for name in expr1._data) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags ) def map_distributed_send_ref_holder( @@ -264,8 +262,8 @@ def map_distributed_send_ref_holder( and self.rec(expr1.passthrough_data, expr2.passthrough_data) and expr1.send.dest_rank == expr2.send.dest_rank and expr1.send.comm_tag == expr2.send.comm_tag - and Taggable.__eq__(expr1.send, expr2.send) - and Taggable.__eq__(expr1, expr2) + and expr1.send.tags == expr2.send.tags + and expr1.tags == expr2.tags ) def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: @@ -274,7 +272,7 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.comm_tag == expr2.comm_tag and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags ) @memoize_method diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 9afb887c0..4fc0fd8bf 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -95,7 +95,9 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string - fields = tuple(field.name for field in attrs.fields(type(expr))) + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + fields = tuple(field.name + for field in attrs.fields(type(expr))) # type: ignore[misc] fields = tuple(field for field in fields if field != "non_equality_tags") @@ -155,7 +157,8 @@ def _get_field_val(field: str) -> str: return (f"{type(expr).__name__}(" + ", ".join(f"{field.name}={_get_field_val(field.name)}" - for field in attrs.fields(type(expr))) + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + for field in attrs.fields(type(expr))) # type: ignore[misc] + ")") def map_loopy_call(self, expr: LoopyCall, depth: int) -> str: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 9cd1ed16a..e759a7e82 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -33,7 +33,7 @@ from immutabledict import immutabledict from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING, - Hashable, cast) + Hashable) from pytato.array import ( Array, IndexLambda, Placeholder, Stack, Roll, @@ -206,7 +206,8 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> Hashable: return expr - def rec(self, expr: ArrayOrNames) -> CachedMapperT: + # type-ignore-reason: incompatible with super class + def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type: ignore[override] key = 
self.get_cache_key(expr) try: return self._cache[key] @@ -217,7 +218,9 @@ def rec(self, expr: ArrayOrNames) -> CachedMapperT: return result # type: ignore[no-any-return] if TYPE_CHECKING: - def __call__(self, expr: ArrayOrNames) -> CachedMapperT: + # type-ignore-reason: incompatible with super class + def __call__(self, expr: ArrayOrNames # type: ignore[override] + ) -> CachedMapperT: return self.rec(expr) # }}} @@ -237,10 +240,15 @@ class CopyMapper(CachedMapper[ArrayOrNames]): This does not copy the data of a :class:`pytato.array.DataWrapper`. """ if TYPE_CHECKING: - def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT: - return cast(CopyMapperResultT, super().rec(expr)) - - def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT: + # type-ignore-reason: specialized variant of super-class' rec method + def rec(self, # type: ignore[override] + expr: CopyMapperResultT) -> CopyMapperResultT: + # type-ignore-reason: CachedMapper.rec's return type is imprecise + return super().rec(expr) # type: ignore[return-value] + + # type-ignore-reason: specialized variant of super-class' rec method + def __call__(self, # type: ignore[override] + expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) def clone_for_callee(self: _SelfMapper) -> _SelfMapper: @@ -1184,13 +1192,17 @@ def __init__(self) -> None: super().__init__() self.topological_order: List[Array] = [] - def get_cache_key(self, expr: ArrayOrNames) -> int: + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] return id(expr) - def post_visit(self, expr: Any) -> None: + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def post_visit(self, expr: Any) -> None: # type: ignore[override] self.topological_order.append(expr) - def map_function_definition(self, expr: FunctionDefinition) -> None: + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def map_function_definition(self, # type: ignore[override] + expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. 
return @@ -1214,7 +1226,8 @@ def clone_for_callee(self: _SelfMapper) -> _SelfMapper: # than Mapper.__init__ and does not have map_fn return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined] - def rec(self, expr: MappedT) -> MappedT: + # type-ignore-reason:incompatible with Mapper.rec() + def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] if expr in self._cache: # type-ignore-reason: parametric Mapping types aren't a thing return self._cache[expr] # type: ignore[return-value] @@ -1225,7 +1238,8 @@ def rec(self, expr: MappedT) -> MappedT: return result # type: ignore[return-value] if TYPE_CHECKING: - def __call__(self, expr: MappedT) -> MappedT: + # type-ignore-reason: Mapper.__call__ returns Any + def __call__(self, expr: MappedT) -> MappedT: # type: ignore[override] return self.rec(expr) # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 8976d76d4..be122a925 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -376,7 +376,7 @@ def test_linear_complexity_inequality(): from pytato.equality import EqualityComparer from numpy.random import default_rng - def construct_intestine_graph(depth=90, seed=0): + def construct_intestine_graph(depth=100, seed=0): rng = default_rng(seed) x = pt.make_placeholder("x", shape=(10,), dtype=float) @@ -650,7 +650,7 @@ def post_visit(self, expr): def test_tag_user_nodes_linear_complexity(): from numpy.random import default_rng - def construct_intestine_graph(depth=90, seed=0): + def construct_intestine_graph(depth=100, seed=0): rng = default_rng(seed) x = pt.make_placeholder("x", shape=(10,), dtype=float) y = x From c5c8920ef1f6f97b2ae3e71e7abbbf4af06f7d3c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 15 Nov 2023 15:16:54 -0600 Subject: [PATCH 072/178] more misc fixes --- pytato/array.py | 2 +- pytato/transform/__init__.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index bdb4714db..fb32fb7f5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1902,7 +1902,7 @@ def roll(a: Array, shift: int, axis: Optional[int] = None) -> Array: if axis is None: if a.ndim > 1: raise NotImplementedError( - "shifing along more than one dimension is unsupported") + "shifting along more than one dimension is unsupported") else: axis = 0 diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index e759a7e82..e4316995b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -419,8 +419,8 @@ def map_function_definition(self, def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function), immutabledict({name: self.rec(bnd) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, + for name, bnd in sorted(expr.bindings.items())}), + tags=expr.tags ) def map_named_call_result(self, expr: NamedCallResult) -> Array: @@ -642,7 +642,7 @@ def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function, *args, **kwargs), immutabledict({name: self.rec(bnd, *args, **kwargs) - for name, bnd in expr.bindings.items()}), + for name, bnd in sorted(expr.bindings.items())}), tags=expr.tags, ) @@ -1540,7 +1540,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: - MPMS materialization strategy is a greedy materialization algorithm in which any node with more than 1 materialized predecessors and more than - 1 successors is materialized. + 1 successor is materialized. 
- Materializing here corresponds to tagging a node with :class:`~pytato.tags.ImplStored`. - Does not attempt to materialize sub-expressions in From 4ec3cbf60dffa9a38e1518269db31dc9c7daa6fc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 15 Nov 2023 15:48:54 -0600 Subject: [PATCH 073/178] copymapper, tests --- pytato/transform/__init__.py | 78 ++++++++++++++++++++++++------------ test/test_pytato.py | 37 +++++++++++++++++ 2 files changed, 89 insertions(+), 26 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index e4316995b..f04f68b11 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -275,7 +275,8 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: bindings=bindings, axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder) -> Array: assert expr.name is not None @@ -283,35 +284,41 @@ def map_placeholder(self, expr: Placeholder) -> Array: shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> Array: arrays = tuple(self.rec(arr) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags) + return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate) -> Array: arrays = tuple(self.rec(arr) for arr in expr.arrays) return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags) + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll) -> Array: return Roll(array=self.rec(expr.array), shift=expr.shift, axis=expr.axis, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: return AxisPermutation(array=self.rec(expr.array), axis_permutation=expr.axis_permutation, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: return type(expr)(self.rec(expr.array), indices=self.rec_idx_or_size_tuple(expr.indices), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex) -> Array: return self._map_index_base(expr) @@ -331,7 +338,8 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: data=expr.data, shape=self.rec_idx_or_size_tuple(expr.shape), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam) -> Array: assert expr.name is not None @@ -343,13 +351,15 @@ def map_einsum(self, expr: Einsum) -> Array: axes=expr.axes, redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, index_to_access_descr=expr.index_to_access_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_named_array(self, expr: NamedArray) -> Array: return type(expr)(self.rec(expr._container), expr.name, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> DictOfNamedArrays: @@ -377,14 +387,16 @@ def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: 
container=rec_container, name=expr.name, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: return Reshape(self.rec(expr.array), newshape=self.rec_idx_or_size_tuple(expr.newshape), order=expr.order, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> Array: @@ -400,7 +412,8 @@ def map_distributed_recv(self, expr: DistributedRecv) -> Array: return DistributedRecv( src_rank=expr.src_rank, comm_tag=expr.comm_tag, shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + dtype=expr.dtype, tags=expr.tags, axes=expr.axes, + non_equality_tags=expr.non_equality_tags) @memoize_method def map_function_definition(self, @@ -492,7 +505,8 @@ def map_index_lambda(self, expr: IndexLambda, bindings=bindings, axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder, *args: Any, **kwargs: Any) -> Array: assert expr.name is not None @@ -501,37 +515,43 @@ def map_placeholder(self, expr: Placeholder, *args: Any, **kwargs: Any) -> Array *args, **kwargs), dtype=expr.dtype, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack, *args: Any, **kwargs: Any) -> Array: arrays = tuple(self.rec(arr, *args, **kwargs) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags) + return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate, *args: Any, **kwargs: Any) -> Array: arrays = tuple(self.rec(arr, *args, **kwargs) for arr in expr.arrays) return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags) + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll, *args: Any, **kwargs: Any) -> Array: return Roll(array=self.rec(expr.array, *args, **kwargs), shift=expr.shift, axis=expr.axis, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation, *args: Any, **kwargs: Any) -> Array: return AxisPermutation(array=self.rec(expr.array, *args, **kwargs), axis_permutation=expr.axis_permutation, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase, *args: Any, **kwargs: Any) -> Array: return type(expr)(self.rec(expr.array, *args, **kwargs), indices=self.rec_idx_or_size_tuple(expr.indices, *args, **kwargs), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex, *args: Any, **kwargs: Any) -> Array: return self._map_index_base(expr, *args, **kwargs) @@ -555,7 +575,8 @@ def map_data_wrapper(self, expr: DataWrapper, data=expr.data, shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam, *args: Any, **kwargs: Any) -> Array: assert expr.name is not None @@ -567,13 +588,15 @@ def map_einsum(self, expr: Einsum, *args: Any, **kwargs: Any) -> Array: axes=expr.axes, 
redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, index_to_access_descr=expr.index_to_access_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_named_array(self, expr: NamedArray, *args: Any, **kwargs: Any) -> Array: return type(expr)(self.rec(expr._container, *args, **kwargs), expr.name, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: Any, **kwargs: Any) -> DictOfNamedArrays: @@ -613,7 +636,8 @@ def map_reshape(self, expr: Reshape, *args, **kwargs), order=expr.order, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, *args: Any, **kwargs: Any) -> Array: @@ -623,14 +647,16 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, dest_rank=expr.send.dest_rank, comm_tag=expr.send.comm_tag), self.rec(expr.passthrough_data, *args, **kwargs), - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_recv(self, expr: DistributedRecv, *args: Any, **kwargs: Any) -> Array: return DistributedRecv( src_rank=expr.src_rank, comm_tag=expr.comm_tag, shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + dtype=expr.dtype, tags=expr.tags, axes=expr.axes, + non_equality_tags=expr.non_equality_tags) def map_function_definition(self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> FunctionDefinition: diff --git a/test/test_pytato.py b/test/test_pytato.py index be122a925..8c0461685 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -857,6 +857,43 @@ def test_created_at(): # }}} + # {{{ Make sure only a single CreatedAt tag is created + + old_tag = tag + + res1 = res1 + res2 + + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) + + assert len(created_tag) == 1 + + tag, = created_tag + + # Tag should be recreated + assert tag != old_tag + + # }}} + + # {{{ Make sure that copying preserves the tag + + old_tag = tag + + res1_new = pt.transform.map_and_copy(res1, lambda x: x) + + created_tag = frozenset({tag + for tag in res1_new.non_equality_tags + if isinstance(tag, CreatedAt)}) + + assert len(created_tag) == 1 + + tag, = created_tag + + assert old_tag == tag + + # }}} + def test_pickling_and_unpickling_is_equal(): from testlib import RandomDAGContext, make_random_dag From 176595dd30381f71dbda40066084c1641be76508 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 17 Nov 2023 10:00:47 -0600 Subject: [PATCH 074/178] explicitly enable/disable traceback --- pytato/__init__.py | 2 ++ pytato/array.py | 16 +++++++++++++--- test/test_pytato.py | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/pytato/__init__.py b/pytato/__init__.py index 572e4a7ab..5255820af 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -113,6 +113,7 @@ def set_debug_enabled(flag: bool) -> None: rewrite_einsums_with_no_broadcasts) from pytato.transform.metadata import unify_axes_tags from pytato.function import trace_call +from pytato.array import enable_traceback_tag __all__ = ( "dtype", @@ -183,4 +184,5 @@ def set_debug_enabled(flag: bool) -> None: # sub-modules "analysis", "tags", "transform", "function", + "enable_traceback_tag", ) diff --git a/pytato/array.py b/pytato/array.py index fb32fb7f5..bcfaa3a3a 100644 --- 
a/pytato/array.py +++ b/pytato/array.py @@ -145,11 +145,12 @@ .. autoclass:: EinsumReductionAxis .. autoclass:: NormalizedSlice -Internal classes for traceback ------------------------------- +Traceback functionality +----------------------- Please consider these undocumented and subject to change at any time. +.. autofunction:: enable_traceback_tag .. class:: _PytatoFrameSummary .. class:: _PytatoStackSummary @@ -1838,11 +1839,20 @@ def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) +_ENABLE_TRACEBACK_TAG = False + + +def enable_traceback_tag(enable: bool = True) -> None: + """Enable or disable the traceback tag.""" + global _ENABLE_TRACEBACK_TAG + _ENABLE_TRACEBACK_TAG = enable + + def _get_created_at_tag() -> Optional[Tag]: import traceback from pytato.tags import CreatedAt - if not __debug__: + if not _ENABLE_TRACEBACK_TAG: return None frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) diff --git a/test/test_pytato.py b/test/test_pytato.py index 8c0461685..674f640d6 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -798,6 +798,8 @@ def test_einsum_dot_axes_has_correct_dim(): def test_created_at(): + pt.enable_traceback_tag() + a = pt.make_placeholder("a", (10, 10), "float64") b = pt.make_placeholder("b", (10, 10), "float64") @@ -894,6 +896,20 @@ def test_created_at(): # }}} + # {{{ Test disabling traceback creation + + pt.enable_traceback_tag(False) + + a = pt.make_placeholder("a", (10, 10), "float64") + + created_tag = frozenset({tag + for tag in a.non_equality_tags + if isinstance(tag, CreatedAt)}) + + assert len(created_tag) == 0 + + # }}} + def test_pickling_and_unpickling_is_equal(): from testlib import RandomDAGContext, make_random_dag From 40557e9f13ac3d7167dac76f93e87a2d70929b84 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 17 Nov 2023 10:41:15 -0600 Subject: [PATCH 075/178] add hash test --- test/test_pytato.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 674f640d6..127f0b7dc 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -812,15 +812,20 @@ def test_created_at(): # CreatedAt tags. 
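    # (Both creations share the same (filename, line number, function, source
    # line) frame summary, so the resulting CreatedAt tags -- and their
    # hashes -- compare equal.)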
res3 = a+b; res4 = a+b # noqa: E702 - # {{{ Check that CreatedAt tags are handled correctly for equality + # {{{ Check that CreatedAt tags are handled correctly for equality/hashing assert res1 == res2 == res3 == res4 + assert hash(res1) == hash(res2) == hash(res3) == hash(res4) assert res1.non_equality_tags != res2.non_equality_tags assert res3.non_equality_tags == res4.non_equality_tags - assert res1.tags == res2.tags - assert res3.tags == res4.tags + assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) + assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) + + assert res1.tags == res2.tags == res3.tags == res4.tags + + assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) # }}} From 524049564eb60de3083e87a42cf74fdec267495f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 17 Nov 2023 11:16:53 -0600 Subject: [PATCH 076/178] undo more unnecessary changes --- pytato/array.py | 6 +++--- pytato/transform/__init__.py | 8 ++++---- pytato/visualization/dot.py | 1 - test/test_pytato.py | 14 ++++++-------- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index bcfaa3a3a..adefe787b 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -277,7 +277,7 @@ def normalize_shape_component( # }}} -# {{{ array interface +# {{{ array inteface ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] IndexExpr = Union[IntegralT, "NormalizedSlice", "Array", None, EllipsisType] @@ -386,7 +386,7 @@ class Array(Taggable): :class:`~pytato.array.IndexLambda` is used to produce references to named arrays. Since any array that needs to be referenced in this way needs to obey this restriction anyway, - a decision was made to require the same of *all* array expressions. + a decision was made to requir the same of *all* array expressions. .. attribute:: dtype @@ -1912,7 +1912,7 @@ def roll(a: Array, shift: int, axis: Optional[int] = None) -> Array: if axis is None: if a.ndim > 1: raise NotImplementedError( - "shifting along more than one dimension is unsupported") + "shifing along more than one dimension is unsupported") else: axis = 0 diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index f04f68b11..041e8df01 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -432,8 +432,8 @@ def map_function_definition(self, def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function), immutabledict({name: self.rec(bnd) - for name, bnd in sorted(expr.bindings.items())}), - tags=expr.tags + for name, bnd in expr.bindings.items()}), + tags=expr.tags, ) def map_named_call_result(self, expr: NamedCallResult) -> Array: @@ -668,7 +668,7 @@ def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function, *args, **kwargs), immutabledict({name: self.rec(bnd, *args, **kwargs) - for name, bnd in sorted(expr.bindings.items())}), + for name, bnd in expr.bindings.items()}), tags=expr.tags, ) @@ -1566,7 +1566,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: - MPMS materialization strategy is a greedy materialization algorithm in which any node with more than 1 materialized predecessors and more than - 1 successor is materialized. + 1 successors is materialized. - Materializing here corresponds to tagging a node with :class:`~pytato.tags.ImplStored`. 
- Does not attempt to materialize sub-expressions in diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 3984d986e..2798a5e2a 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -190,7 +190,6 @@ def handle_unsupported_array(self, # type: ignore[override] for field in attrs.fields(type(expr)): # type: ignore[misc] if field.name in info.fields: continue - attr = getattr(expr, field.name) if isinstance(attr, Array): diff --git a/test/test_pytato.py b/test/test_pytato.py index 127f0b7dc..7672999e7 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -449,12 +449,12 @@ def test_array_dot_repr(): x = pt.make_placeholder("x", (10, 4), np.int64) y = pt.make_placeholder("y", (10, 4), np.int64) - def _assert_repr(ary: pt.Array, expected_repr: str): + def _assert_stripped_repr(ary: pt.Array, expected_repr: str): expected_str = "".join([c for c in expected_repr if c not in [" ", "\n"]]) result_str = "".join([c for c in repr(ary)if c not in [" ", "\n"]]) assert expected_str == result_str - _assert_repr( + _assert_stripped_repr( 3*x + 4*y, """ IndexLambda( @@ -483,7 +483,7 @@ def _assert_repr(ary: pt.Array, expected_repr: str): dtype='int64', name='y')})})""") - _assert_repr( + _assert_stripped_repr( pt.roll(x.reshape(2, 20).reshape(-1), 3), """ Roll( @@ -495,7 +495,7 @@ def _assert_repr(ary: pt.Array, expected_repr: str): newshape=(40), order='C'), shift=3, axis=0)""") - _assert_repr(y * pt.not_equal(x, 3), + _assert_stripped_repr(y * pt.not_equal(x, 3), """ IndexLambda( shape=(10, 4), @@ -515,7 +515,7 @@ def _assert_repr(ary: pt.Array, expected_repr: str): bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='x')})})""") - _assert_repr( + _assert_stripped_repr( x[y[:, 2:3], x[2, :]], """ AdvancedIndexInContiguousAxes( @@ -530,7 +530,7 @@ def _assert_repr(ary: pt.Array, expected_repr: str): name='x'), indices=(2, NormalizedSlice(start=0, stop=4, step=1)))))""") - _assert_repr( + _assert_stripped_repr( pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]), """ Stack( @@ -819,12 +819,10 @@ def test_created_at(): assert res1.non_equality_tags != res2.non_equality_tags assert res3.non_equality_tags == res4.non_equality_tags - assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) assert res1.tags == res2.tags == res3.tags == res4.tags - assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) # }}} From 9338f0be61a1213563d3851467aca4e80a2cdba7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Nov 2023 15:10:23 -0600 Subject: [PATCH 077/178] more lint fixes --- pytato/array.py | 5 +++-- pytato/distributed/nodes.py | 6 ++++-- pytato/visualization/dot.py | 10 ++++++---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 2cefb9078..c188ec72c 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -450,8 +450,9 @@ class Array(Taggable): tags: FrozenSet[Tag] = attrs.field(kw_only=True) # These are automatically excluded from equality in EqualityComparer - non_equality_tags: FrozenSet[Tag] = attrs.field(kw_only=True, hash=False, - default=frozenset()) + non_equality_tags: FrozenSet[Optional[Tag]] = attrs.field(kw_only=True, + hash=False, + default=frozenset()) _mapper_method: ClassVar[str] diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index 465bda312..e95217b82 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -149,8 
+149,10 @@ class DistributedSendRefHolder(Array): _mapper_method: ClassVar[str] = "map_distributed_send_ref_holder" def __init__(self, send: DistributedSend, passthrough_data: Array, - tags: FrozenSet[Tag] = frozenset()) -> None: - super().__init__(axes=passthrough_data.axes, tags=tags) + tags: FrozenSet[Tag] = frozenset(), + non_equality_tags: FrozenSet[Optional[Tag]] = frozenset()) -> None: + super().__init__(axes=passthrough_data.axes, tags=tags, + non_equality_tags=non_equality_tags) object.__setattr__(self, "send", send) object.__setattr__(self, "passthrough_data", passthrough_data) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index b4128f1fc..69b7cd21a 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -144,11 +144,11 @@ def emit_subgraph(sg: _SubgraphTree) -> None: @attrs.define class _DotNodeInfo: title: str - fields: Dict[str, str] + fields: Dict[str, Any] edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] -def stringify_tags(tags: FrozenSet[Tag]) -> str: +def stringify_tags(tags: FrozenSet[Optional[Tag]]) -> str: components = sorted(str(elem) for elem in tags) return "{" + ", ".join(components) + "}" @@ -373,14 +373,16 @@ def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str: return "" -def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str], +def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' rows = [f"{dot_escape(title)}"] - non_equality_tags = fields.pop("non_equality_tags", frozenset()) + non_equality_tags: FrozenSet[Any] = fields.pop("non_equality_tags", frozenset()) + + print(non_equality_tags) tooltip = dot_escape_leave_space(_stringify_created_at(non_equality_tags)) for name, field in fields.items(): From f5cb92ffe209ff506e5956ae57975d558dbb537a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Nov 2023 15:56:44 -0600 Subject: [PATCH 078/178] run all examples, fix demo_distributed_node_duplication --- .github/workflows/ci.yml | 2 +- examples/demo_distributed_node_duplication.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6fd07337c..07d11a052 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,7 +77,7 @@ jobs: . ci-support-v0 build_py_project_in_conda_env pip install pytest # for advection.py - run_examples + run_examples --no-require-main docs: name: Documentation diff --git a/examples/demo_distributed_node_duplication.py b/examples/demo_distributed_node_duplication.py index 39307ccfb..9dd0670ae 100644 --- a/examples/demo_distributed_node_duplication.py +++ b/examples/demo_distributed_node_duplication.py @@ -1,15 +1,20 @@ """ An example to demonstrate the behavior of -:func:`pytato.find_distrbuted_partition`. One of the key characteristic of the -partitioning routine is to recompute expressions that appear in the multiple +:func:`pytato.find_distributed_partition`. One of the key characteristics of the +partitioning routine is to recompute expressions that appear in multiple partitions but are not materialized. 
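A node can be excluded from this recomputation by materializing it
explicitly. A sketch, assuming the usual tagging interface (``tmp2`` is the
recomputed node below)::

    tmp2 = tmp2.tagged(pt.tags.ImplStored())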
""" import pytato as pt import numpy as np +from mpi4py import MPI + +comm = MPI.COMM_WORLD size = 2 rank = 0 +pt.enable_traceback_tag() + x1 = pt.make_placeholder("x1", shape=(10, 4), dtype=np.float64) x2 = pt.make_placeholder("x2", shape=(10, 4), dtype=np.float64) x3 = pt.make_placeholder("x3", shape=(10, 4), dtype=np.float64) @@ -30,7 +35,7 @@ out = tmp2 + recv result = pt.make_dict_of_named_arrays({"out": out}) -partitions = pt.find_distributed_partition(result) +partitions = pt.find_distributed_partition(comm, result) # Visualize *partitions* to see that each of the two partitions contains a node # named 'tmp2'. From 36166c6e85a55e6530c98d80d6a759e2e7df6fd5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Nov 2023 16:15:25 -0600 Subject: [PATCH 079/178] enable CreatedAt for distributed nodes --- examples/mpi-distributed.py | 2 ++ examples/visualization.py | 2 ++ pytato/array.py | 9 ++++++--- pytato/distributed/nodes.py | 13 +++++++++---- pytato/visualization/dot.py | 1 - 5 files changed, 19 insertions(+), 8 deletions(-) diff --git a/examples/mpi-distributed.py b/examples/mpi-distributed.py index dd29a82ab..ce8bcdab4 100644 --- a/examples/mpi-distributed.py +++ b/examples/mpi-distributed.py @@ -15,6 +15,8 @@ def main(): + pt.enable_traceback_tag() + ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) diff --git a/examples/visualization.py b/examples/visualization.py index ac71e6060..f18262cb5 100755 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -17,6 +17,8 @@ def main(): + pt.enable_traceback_tag() + n = pt.make_size_param("n") array = pt.make_placeholder(name="array", shape=n, dtype=np.float64) stack = pt.stack([array, 2*array, array + 6]) diff --git a/pytato/array.py b/pytato/array.py index c188ec72c..4b5373505 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1091,6 +1091,7 @@ def with_tagged_reduction(self, (new_redn_axis_to_redn_descr), tags=self.tags, index_to_access_descr=self.index_to_access_descr, + non_equality_tags=self.non_equality_tags, ) @@ -1297,6 +1298,7 @@ def einsum(subscripts: str, *operands: Array, ), redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr), index_to_access_descr=index_to_descr, + non_equality_tags=frozenset({_get_created_at_tag()}), ) # }}} @@ -1825,11 +1827,12 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: def short_str(self, maxlen: int = 100) -> str: from os.path import dirname - # Find the first file in the frames that is not in pytato's pytato/ - # directory. + # Find the first file in the frames that is not in pytato's internal + # directories. 
for frame in reversed(self.frames): frame_dir = dirname(frame.filename) - if not frame_dir.endswith("pytato"): + if (not frame_dir.endswith("pytato") + and not frame_dir.endswith("pytato/distributed")): return frame.short_str(maxlen) # Fallback in case we don't find any file that is not in the pytato/ diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index e95217b82..617876c97 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -64,7 +64,8 @@ from pytato.array import ( Array, _SuppliedShapeAndDtypeMixin, ShapeType, AxesT, - _get_default_axes, ConvertibleToShape, normalize_shape) + _get_default_axes, ConvertibleToShape, normalize_shape, + _get_created_at_tag) CommTagType = Hashable @@ -170,13 +171,15 @@ def copy(self, **kwargs: Any) -> DistributedSendRefHolder: send = kwargs.pop("send", self.send) passthrough_data = kwargs.pop("passthrough_data", self.passthrough_data) tags = kwargs.pop("tags", self.tags) + non_equality_tags = kwargs.pop("non_equality_tags", self.non_equality_tags) if kwargs: raise ValueError("Cannot assign" f" DistributedSendRefHolder.'{set(kwargs)}'") return DistributedSendRefHolder(send, passthrough_data, - tags) + tags, + non_equality_tags) # }}} @@ -238,7 +241,8 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT return DistributedSendRefHolder( send=DistributedSend(data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, tags=send_tags), - passthrough_data=stapled_to, tags=ref_holder_tags) + passthrough_data=stapled_to, tags=ref_holder_tags, + non_equality_tags=frozenset({_get_created_at_tag()})) def make_distributed_recv(src_rank: int, comm_tag: CommTagType, @@ -255,7 +259,8 @@ def make_distributed_recv(src_rank: int, comm_tag: CommTagType, dtype = np.dtype(dtype) return DistributedRecv( src_rank=src_rank, comm_tag=comm_tag, shape=shape, dtype=dtype, - tags=tags, axes=axes) + tags=tags, axes=axes, + non_equality_tags=frozenset({_get_created_at_tag()})) # }}} diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 69b7cd21a..9c59c4e4a 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -382,7 +382,6 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], non_equality_tags: FrozenSet[Any] = fields.pop("non_equality_tags", frozenset()) - print(non_equality_tags) tooltip = dot_escape_leave_space(_stringify_created_at(non_equality_tags)) for name, field in fields.items(): From 4c3b06a82a9e9b2eefc3307d3ed19d29efcb00d2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 29 Nov 2023 12:43:11 -0600 Subject: [PATCH 080/178] undo MPI tag ordering --- pytato/distributed/tags.py | 49 +++++++++++++++++++++----------------- test/test_distributed.py | 10 ++------ 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 1420b7532..f9eb3b7ec 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -31,7 +31,7 @@ """ -from typing import TYPE_CHECKING, Tuple, TypeVar +from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar from pytato.distributed.partition import DistributedGraphPartition @@ -63,40 +63,45 @@ def number_distributed_tags( This is a potentially heavyweight MPI-collective operation on *mpi_communicator*. 
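    A sketch of typical use, assuming *partition* was obtained from
    :func:`~pytato.find_distributed_partition` (the base tag value is
    illustrative)::

        partition, next_tag = number_distributed_tags(
                comm, partition, base_tag=4200)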
""" - from pytools import flatten - - tags = tuple([ + tags = frozenset({ recv.comm_tag for part in partition.parts.values() for recv in part.name_to_recv_node.values() - ] + [ + } | { send.comm_tag for part in partition.parts.values() for sends in part.name_to_send_nodes.values() - for send in sends]) + for send in sends}) + + from mpi4py import MPI + + def set_union( + set_a: FrozenSet[T], set_b: FrozenSet[T], + mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]: + assert mpi_data_type is None + assert isinstance(set_a, frozenset) + assert isinstance(set_b, frozenset) + + return set_a | set_b root_rank = 0 - all_tags = mpi_communicator.gather(tags, root=root_rank) + set_union_mpi_op = MPI.Op.Create( + # type ignore reason: mpi4py misdeclares op functions as returning + # None. + set_union, # type: ignore[arg-type] + commute=True) + try: + all_tags = mpi_communicator.reduce( + tags, set_union_mpi_op, root=root_rank) + finally: + set_union_mpi_op.Free() if mpi_communicator.rank == root_rank: sym_tag_to_int_tag = {} next_tag = base_tag - assert isinstance(all_tags, list) - assert len(all_tags) == mpi_communicator.size - - # First previous version - # for sym_tag in sorted(all_tags, key=lambda tag: repr(tag)): - # sym_tag_to_int_tag[sym_tag] = next_tag - # next_tag += 1 - # - # - # Second previous version - # for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call] - # if sym_tag not in sym_tag_to_int_tag: - # sym_tag_to_int_tag[sym_tag] = next_tag - # next_tag += 1 - # Current main version + assert isinstance(all_tags, frozenset) + for sym_tag in all_tags: sym_tag_to_int_tag[sym_tag] = next_tag next_tag += 1 diff --git a/test/test_distributed.py b/test/test_distributed.py index 6e2e34376..925d2e070 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -266,7 +266,7 @@ def _do_test_distributed_execution_random_dag(ctx_factory): ntests = 10 for i in range(ntests): seed = 120 + i - print(f"Step {i} {seed=}") + print(f"Step {i} {seed}") # {{{ compute value with communication @@ -278,13 +278,7 @@ def gen_comm(rdagc): nonlocal comm_tag comm_tag += 1 - - if comm_tag % 5 == 1 or 1: - tag = (comm_tag, frozenset([_RandomDAGTag, "a", comm_tag])) - elif comm_tag % 5 == 2: - tag = (comm_tag, (_RandomDAGTag, "b")) - else: - tag = (comm_tag, _RandomDAGTag) # noqa: B023 + tag = (comm_tag, _RandomDAGTag) # noqa: B023 inner = make_random_dag(rdagc) return pt.staple_distributed_send( From 06503b1cc548b6ae6061900bcc933453431e269e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 2 Feb 2024 11:00:41 -0600 Subject: [PATCH 081/178] get precise traceback of array creation --- pytato/array.py | 47 +++++++++++++++++++++++++++++++++++++++-------- pytato/cmath.py | 2 +- pytato/utils.py | 33 +++++++++++++++++++-------------- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 4b5373505..ffbee5b21 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -508,7 +508,11 @@ def __getitem__(self, slice_spec = (slice_spec,) from pytato.utils import _index_into - return _index_into(self, slice_spec) + return _index_into( + self, + slice_spec, + tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()})) @property def ndim(self) -> int: @@ -553,13 +557,20 @@ def _binary_op(self, # }}} + tags = _get_default_tags() + non_equality_tags = frozenset({_get_created_at_tag(stacklevel=2)}) + import pytato.utils as utils if reverse: result = utils.broadcast_binary_op(other, self, op, - get_result_type) + get_result_type, + 
tags=tags, + non_equality_tags=non_equality_tags) else: result = utils.broadcast_binary_op(self, other, op, - get_result_type) + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) assert isinstance(result, Array) return result @@ -579,6 +590,7 @@ def _unary_op(self, op: Any) -> Array: bindings=bindings, tags=_get_default_tags(), axes=_get_default_axes(self.ndim), + non_equality_tags=frozenset({_get_created_at_tag(stacklevel=2)}), var_to_reduction_descr=immutabledict()) __mul__ = partialmethod(_binary_op, operator.mul) @@ -1852,15 +1864,25 @@ def enable_traceback_tag(enable: bool = True) -> None: _ENABLE_TRACEBACK_TAG = enable -def _get_created_at_tag() -> Optional[Tag]: +def _get_created_at_tag(stacklevel: int = 1) -> Optional[Tag]: + """ + Get a :class:`CreatedAt` tag storing the stack trace of an array's creation. + + :param stacklevel: the number of stack levels above this call to record as the + array creation location + """ import traceback from pytato.tags import CreatedAt if not _ENABLE_TRACEBACK_TAG: return None + # Drop the stack levels corresponding to extract_stack() and any additional + # levels specified via stacklevel + stack = traceback.extract_stack()[:-(1+stacklevel)] + frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) - for s in traceback.extract_stack()) + for s in stack) return CreatedAt(_PytatoStackSummary(frames)) @@ -2376,7 +2398,10 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Union[Array, b # '_compare' returns a bool. return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.Comparison(x, which, y), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag(stacklevel=2)}), ) # type: ignore[return-value] @@ -2436,7 +2461,10 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Union[Array, bool]: import pytato.utils as utils return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.LogicalOr((x, y)), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag()}), ) # type: ignore[return-value] @@ -2450,7 +2478,10 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Union[Array, bool]: import pytato.utils as utils return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.LogicalAnd((x, y)), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag()}), ) # type: ignore[return-value] diff --git a/pytato/cmath.py b/pytato/cmath.py index 9f8e8c7fa..3afbe82ed 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -115,7 +115,7 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], tuple(sym_args)), shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + non_equality_tags=frozenset({_get_created_at_tag(stacklevel=2)}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict(), ) diff --git a/pytato/utils.py b/pytato/utils.py index 94c43938c..e5463d8f5 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -27,7 +27,7 @@ import pymbolic.primitives as prim from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, - Optional, Iterable, TypeVar) + Optional, Iterable, TypeVar, FrozenSet) from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, DtypeOrScalar, ArrayOrScalar, 
BasicIndex, AdvancedIndexInContiguousAxes, @@ -38,6 +38,7 @@ SCALAR_CLASSES, INT_CLASSES, BoolT, ScalarType) from pytools import UniqueNameGenerator from pytato.transform import Mapper +from pytools.tag import Tag from immutabledict import immutabledict @@ -178,9 +179,10 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Optional[Tag]], ) -> ArrayOrScalar: - from pytato.array import (_get_default_axes, _get_default_tags, - _get_created_at_tag) + from pytato.array import _get_default_axes if isinstance(a1, SCALAR_CLASSES): a1 = np.dtype(type(a1)).type(a1) @@ -207,8 +209,8 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, shape=result_shape, dtype=result_dtype, bindings=immutabledict(bindings), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, var_to_reduction_descr=immutabledict(), axes=_get_default_axes(len(result_shape))) @@ -475,10 +477,13 @@ def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent: # }}} -def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Array: +def _index_into( + ary: Array, + indices: Tuple[ConvertibleToIndexExpr, ...], + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Optional[Tag]]) -> Array: from pytato.diagnostic import CannotBroadcastError - from pytato.array import (_get_default_axes, _get_default_tags, - _get_created_at_tag) + from pytato.array import _get_default_axes # {{{ handle ellipsis @@ -564,24 +569,24 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra return AdvancedIndexInNoncontiguousAxes( ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: return AdvancedIndexInContiguousAxes( ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: # basic indexing expression return BasicIndex(ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes( len([idx for idx in normalized_indices From ab87fbf4ad2fb3d1741f31055a5665e94ec6926e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 2 Feb 2024 13:28:02 -0600 Subject: [PATCH 082/178] partialmethod doesn't introduce a stack frame --- pytato/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index ffbee5b21..66d566576 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -558,7 +558,7 @@ def _binary_op(self, # }}} tags = _get_default_tags() - non_equality_tags = frozenset({_get_created_at_tag(stacklevel=2)}) + non_equality_tags = frozenset({_get_created_at_tag()}) import pytato.utils as utils if reverse: @@ -590,7 +590,7 @@ def _unary_op(self, op: Any) -> Array: bindings=bindings, tags=_get_default_tags(), axes=_get_default_axes(self.ndim), - 
non_equality_tags=frozenset({_get_created_at_tag(stacklevel=2)}), + non_equality_tags=frozenset({_get_created_at_tag()}), var_to_reduction_descr=immutabledict()) __mul__ = partialmethod(_binary_op, operator.mul) From d8df5f83b16762e2f9e10f4f067bad2401c02c35 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 6 Feb 2024 10:15:53 -0600 Subject: [PATCH 083/178] add support for make_distributed_send_ref_holder --- pytato/distributed/nodes.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index b5796d888..07befad64 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -231,12 +231,16 @@ def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagTyp def make_distributed_send_ref_holder( send: DistributedSend, passthrough_data: Array, - tags: FrozenSet[Tag] = frozenset() + tags: FrozenSet[Tag] = frozenset(), + non_equality_tags: FrozenSet[Optional[Tag]] = frozenset(), ) -> DistributedSendRefHolder: """Make a :class:`DistributedSendRefHolder` object.""" + if not non_equality_tags: + non_equality_tags = frozenset({_get_created_at_tag()}) return DistributedSendRefHolder( send=send, passthrough_data=passthrough_data, - tags=(tags | _get_default_tags())) + tags=(tags | _get_default_tags()), + non_equality_tags=non_equality_tags) def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, @@ -251,7 +255,8 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT sent_data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, send_tags=send_tags), passthrough_data=stapled_to, - tags=ref_holder_tags) + tags=ref_holder_tags, + non_equality_tags=frozenset({_get_created_at_tag()})) def make_distributed_recv(src_rank: int, comm_tag: CommTagType, From e1b918156723d0c2f582a780c4ba6feb413e6c11 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 16 Feb 2024 10:44:27 -0600 Subject: [PATCH 084/178] add to MPMSMaterializer --- pytato/transform/__init__.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 57cf7d6ce..6de82bb29 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1344,14 +1344,16 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: for bnd_name, bnd in sorted(children_rec.items())}), axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], children_rec.values()) def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] new_expr = Stack(tuple(ary.expr for ary in rec_arrays), - expr.axis, axes=expr.axes, tags=expr.tags) + expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1362,7 +1364,8 @@ def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: new_expr = Concatenate(tuple(ary.expr for ary in rec_arrays), expr.axis, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], rec_arrays) @@ -1370,7 +1373,8 @@ def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: def map_roll(self, expr: Roll) -> 
MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -1378,7 +1382,8 @@ def map_axis_permutation(self, expr: AxisPermutation ) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, - axes=expr.axes, tags=expr.tags) + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -1396,7 +1401,8 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: for i in range( len(expr.indices))), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1410,7 +1416,8 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) new_expr = Reshape(rec_array.expr, expr.newshape, - expr.order, axes=expr.axes, tags=expr.tags) + expr.order, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1423,7 +1430,8 @@ def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: expr.redn_axis_to_redn_descr, expr.index_to_access_descr, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], From be9dcddd8c73475faf7dff21e11a9371e436eead Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Sat, 2 Mar 2024 08:47:37 -0600 Subject: [PATCH 085/178] Spew array tracing to stdout. 
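When axis tag inference comes up empty for an axis, print the array's type
and its ``non_equality_tags`` (which carry the ``CreatedAt`` creation
traceback), so that the array's creation site can be located.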
--- pytato/array.py | 1 + pytato/transform/metadata.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 2b1ec11b1..65d91857c 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1488,6 +1488,7 @@ class Reshape(IndexRemappingBase): if __debug__: def __attrs_post_init__(self) -> None: + assert self.non_equality_tags super().__attrs_post_init__() @property diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 7f015d739..780d96fd4 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -614,8 +614,12 @@ def rec(self, expr: ArrayOrNames) -> Any: assert expr_copy.ndim == expr.ndim for iaxis in range(expr.ndim): + axis_tags = self.axis_to_tags.get((expr, iaxis), []) + if len(axis_tags) == 0: + print(f"failed to infer axis {iaxis} of array of type {type(expr)}.") + print(f"{expr.non_equality_tags=}") expr_copy = expr_copy.with_tagged_axis( - iaxis, self.axis_to_tags.get((expr, iaxis), [])) + iaxis, axis_tags) # {{{ tag reduction descrs From ad0aa4c01f3d9aabcf8a53ea0cbf1ca5b7bce0f1 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Thu, 7 Mar 2024 12:44:10 -0600 Subject: [PATCH 086/178] Get precise traceback of array creation (#480) --- pytato/array.py | 47 +++++++++++++++++++++++++++++++++++++++-------- pytato/cmath.py | 2 +- pytato/utils.py | 33 +++++++++++++++++++-------------- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 447b4d3d6..a88c0e948 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -508,7 +508,11 @@ def __getitem__(self, slice_spec = (slice_spec,) from pytato.utils import _index_into - return _index_into(self, slice_spec) + return _index_into( + self, + slice_spec, + tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()})) @property def ndim(self) -> int: @@ -553,13 +557,20 @@ def _binary_op(self, # }}} + tags = _get_default_tags() + non_equality_tags = frozenset({_get_created_at_tag()}) + import pytato.utils as utils if reverse: result = utils.broadcast_binary_op(other, self, op, - get_result_type) + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) else: result = utils.broadcast_binary_op(self, other, op, - get_result_type) + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) assert isinstance(result, Array) return result @@ -579,6 +590,7 @@ def _unary_op(self, op: Any) -> Array: bindings=bindings, tags=_get_default_tags(), axes=_get_default_axes(self.ndim), + non_equality_tags=frozenset({_get_created_at_tag()}), var_to_reduction_descr=immutabledict()) __mul__ = partialmethod(_binary_op, operator.mul) @@ -1852,15 +1864,25 @@ def enable_traceback_tag(enable: bool = True) -> None: _ENABLE_TRACEBACK_TAG = enable -def _get_created_at_tag() -> Optional[Tag]: +def _get_created_at_tag(stacklevel: int = 1) -> Optional[Tag]: + """ + Get a :class:`CreatedAt` tag storing the stack trace of an array's creation. 
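    (For instance, ``_compare`` below passes ``stacklevel=2`` so that the
    recorded creation site is the caller's comparison expression rather than
    ``_compare``'s own frame.)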
+ + :param stacklevel: the number of stack levels above this call to record as the + array creation location + """ import traceback from pytato.tags import CreatedAt if not _ENABLE_TRACEBACK_TAG: return None + # Drop the stack levels corresponding to extract_stack() and any additional + # levels specified via stacklevel + stack = traceback.extract_stack()[:-(1+stacklevel)] + frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) - for s in traceback.extract_stack()) + for s in stack) return CreatedAt(_PytatoStackSummary(frames)) @@ -2378,7 +2400,10 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Union[Array, b # '_compare' returns a bool. return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.Comparison(x, which, y), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag(stacklevel=2)}), ) # type: ignore[return-value] @@ -2438,7 +2463,10 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Union[Array, bool]: import pytato.utils as utils return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.LogicalOr((x, y)), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag()}), ) # type: ignore[return-value] @@ -2452,7 +2480,10 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Union[Array, bool]: import pytato.utils as utils return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.LogicalAnd((x, y)), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag()}), ) # type: ignore[return-value] diff --git a/pytato/cmath.py b/pytato/cmath.py index 9f8e8c7fa..3afbe82ed 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -115,7 +115,7 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], tuple(sym_args)), shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + non_equality_tags=frozenset({_get_created_at_tag(stacklevel=2)}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict(), ) diff --git a/pytato/utils.py b/pytato/utils.py index 94c43938c..e5463d8f5 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -27,7 +27,7 @@ import pymbolic.primitives as prim from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, - Optional, Iterable, TypeVar) + Optional, Iterable, TypeVar, FrozenSet) from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, DtypeOrScalar, ArrayOrScalar, BasicIndex, AdvancedIndexInContiguousAxes, @@ -38,6 +38,7 @@ SCALAR_CLASSES, INT_CLASSES, BoolT, ScalarType) from pytools import UniqueNameGenerator from pytato.transform import Mapper +from pytools.tag import Tag from immutabledict import immutabledict @@ -178,9 +179,10 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Optional[Tag]], ) -> ArrayOrScalar: - from pytato.array import (_get_default_axes, _get_default_tags, - _get_created_at_tag) + from pytato.array import _get_default_axes if isinstance(a1, SCALAR_CLASSES): a1 = 
np.dtype(type(a1)).type(a1) @@ -207,8 +209,8 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, shape=result_shape, dtype=result_dtype, bindings=immutabledict(bindings), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, var_to_reduction_descr=immutabledict(), axes=_get_default_axes(len(result_shape))) @@ -475,10 +477,13 @@ def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent: # }}} -def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Array: +def _index_into( + ary: Array, + indices: Tuple[ConvertibleToIndexExpr, ...], + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Optional[Tag]]) -> Array: from pytato.diagnostic import CannotBroadcastError - from pytato.array import (_get_default_axes, _get_default_tags, - _get_created_at_tag) + from pytato.array import _get_default_axes # {{{ handle ellipsis @@ -564,24 +569,24 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra return AdvancedIndexInNoncontiguousAxes( ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: return AdvancedIndexInContiguousAxes( ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: # basic indexing expression return BasicIndex(ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes( len([idx for idx in normalized_indices From ab5728e67e4e5ed0f251bfaadd09fb11d8f74fc7 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 11 Apr 2024 08:06:31 -0500 Subject: [PATCH 087/178] Disable assert non_equality_tag --- pytato/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index f4b952947..598fa8c7b 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1501,7 +1501,7 @@ class Reshape(IndexRemappingBase): if __debug__: def __attrs_post_init__(self) -> None: - assert self.non_equality_tags + # assert self.non_equality_tags super().__attrs_post_init__() @property From 38e4332d8aea44900061b4e4d5891be33eff064b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 25 Sep 2023 17:36:38 -0500 Subject: [PATCH 088/178] add PytatoKeyBuilder --- pytato/analysis/__init__.py | 29 +++++++++++++++++++++++++++++ test/test_pytato.py | 22 ++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..59350329f 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -36,6 +36,7 @@ from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method +from loopy.tools import LoopyKeyBuilder, PersistentHashWalkMapper if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder @@ -463,4 +464,32 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} + +# {{{ PytatoKeyBuilder + +class PytatoKeyBuilder(LoopyKeyBuilder): + """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass + for objects 
within :mod:`pytato`. + """ + + def update_for_ndarray(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, hash(key.data.tobytes())) # type: ignore[no-untyped-call] + + def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: + if key is None: + self.update_for_NoneType(key_hash, key) # type: ignore[no-untyped-call] + else: + PersistentHashWalkMapper(key_hash)(key) + + update_for_Product = update_for_pymbolic_expression # noqa: N815 + update_for_Sum = update_for_pymbolic_expression # noqa: N815 + update_for_If = update_for_pymbolic_expression # noqa: N815 + update_for_LogicalOr = update_for_pymbolic_expression # noqa: N815 + update_for_Call = update_for_pymbolic_expression # noqa: N815 + update_for_Comparison = update_for_pymbolic_expression # noqa: N815 + update_for_Quotient = update_for_pymbolic_expression # noqa: N815 + update_for_Power = update_for_pymbolic_expression # noqa: N815 + +# }}} + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 8939073cb..e5d8c2b5f 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1231,6 +1231,28 @@ def test_dot_visualizers(): # }}} +def test_persistent_dict(): + from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError + from pytato.analysis import PytatoKeyBuilder + + axis_len = 5 + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=PytatoKeyBuilder(), + container_dir="./pytest-pdict") + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=True) + + dag = make_random_dag(rdagc) + pd[dag] = 42 + + # Make sure key stays the same + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From 4dd3250a4b7be6f86690c075a69c1ef569ca3137 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 25 Sep 2023 18:04:55 -0500 Subject: [PATCH 089/178] mypy fixes --- pytato/analysis/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 59350329f..c31cb1e5a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -467,17 +467,17 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # {{{ PytatoKeyBuilder -class PytatoKeyBuilder(LoopyKeyBuilder): +class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc] """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass for objects within :mod:`pytato`. 
""" def update_for_ndarray(self, key_hash: Any, key: Any) -> None: - self.rec(key_hash, hash(key.data.tobytes())) # type: ignore[no-untyped-call] + self.rec(key_hash, hash(key.data.tobytes())) def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: if key is None: - self.update_for_NoneType(key_hash, key) # type: ignore[no-untyped-call] + self.update_for_NoneType(key_hash, key) else: PersistentHashWalkMapper(key_hash)(key) From 970e7bbdc0bd055554bdc38e0853243c2b5440ef Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 28 Sep 2023 16:52:12 -0500 Subject: [PATCH 090/178] support TaggableCLArray, Subscript --- pytato/analysis/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c31cb1e5a..e5e852143 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -475,6 +475,9 @@ class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc] def update_for_ndarray(self, key_hash: Any, key: Any) -> None: self.rec(key_hash, hash(key.data.tobytes())) + def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: + self.update_for_ndarray(key_hash, key.get()) + def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: if key is None: self.update_for_NoneType(key_hash, key) @@ -489,6 +492,7 @@ def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: update_for_Comparison = update_for_pymbolic_expression # noqa: N815 update_for_Quotient = update_for_pymbolic_expression # noqa: N815 update_for_Power = update_for_pymbolic_expression # noqa: N815 + update_for_Subscript = update_for_pymbolic_expression # noqa: N815 # }}} From 95dec097a03bae5380fa24541ee11b701faef572 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 28 Sep 2023 18:24:55 -0500 Subject: [PATCH 091/178] CL Array, function --- pytato/analysis/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e5e852143..9af7b14bb 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -478,6 +478,13 @@ def update_for_ndarray(self, key_hash: Any, key: Any) -> None: def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: self.update_for_ndarray(key_hash, key.get()) + def update_for_Array(self, key_hash: Any, key: Any) -> None: + # CL Array + self.update_for_ndarray(key_hash, key.get()) + + def update_for_function(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, key.__name__) + def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: if key is None: self.update_for_NoneType(key_hash, key) From 2ac10eea98a6520648a676e8ae089bed5ddb4f2a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 11:13:20 -0600 Subject: [PATCH 092/178] add prim.Variable --- pytato/analysis/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 9af7b14bb..fb0bb1025 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -500,6 +500,7 @@ def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: update_for_Quotient = update_for_pymbolic_expression # noqa: N815 update_for_Power = update_for_pymbolic_expression # noqa: N815 update_for_Subscript = update_for_pymbolic_expression # noqa: N815 + update_for_Variable = update_for_pymbolic_expression # noqa: N815 # }}} From 62a13aeca6214b8848769507e3a32205c9e76fe5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: 
Mon, 5 Feb 2024 12:13:21 -0600 Subject: [PATCH 093/178] fixes to ndarray, pymb expressions --- pytato/analysis/__init__.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fb0bb1025..d74092590 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -473,34 +473,28 @@ class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc] """ def update_for_ndarray(self, key_hash: Any, key: Any) -> None: - self.rec(key_hash, hash(key.data.tobytes())) + self.rec(key_hash, key.data.tobytes()) def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: - self.update_for_ndarray(key_hash, key.get()) + self.rec(key_hash, key.get()) def update_for_Array(self, key_hash: Any, key: Any) -> None: # CL Array - self.update_for_ndarray(key_hash, key.get()) + self.rec(key_hash, key.get()) def update_for_function(self, key_hash: Any, key: Any) -> None: - self.rec(key_hash, key.__name__) - - def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: - if key is None: - self.update_for_NoneType(key_hash, key) - else: - PersistentHashWalkMapper(key_hash)(key) - - update_for_Product = update_for_pymbolic_expression # noqa: N815 - update_for_Sum = update_for_pymbolic_expression # noqa: N815 - update_for_If = update_for_pymbolic_expression # noqa: N815 - update_for_LogicalOr = update_for_pymbolic_expression # noqa: N815 - update_for_Call = update_for_pymbolic_expression # noqa: N815 - update_for_Comparison = update_for_pymbolic_expression # noqa: N815 - update_for_Quotient = update_for_pymbolic_expression # noqa: N815 - update_for_Power = update_for_pymbolic_expression # noqa: N815 - update_for_Subscript = update_for_pymbolic_expression # noqa: N815 - update_for_Variable = update_for_pymbolic_expression # noqa: N815 + self.rec(key_hash, key.__module__ + key.__qualname__) + + update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 # }}} From b8e04bf00130faa3c17ce6f938f2ca15e7ac85b8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 12:17:49 -0600 Subject: [PATCH 094/178] flake8 --- pytato/analysis/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d74092590..41a5842b0 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -36,7 +36,7 @@ from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method -from loopy.tools import LoopyKeyBuilder, PersistentHashWalkMapper +from loopy.tools import LoopyKeyBuilder if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, 
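With the key builder in place, expression graphs can serve as keys in a
:mod:`pytools` persistent dict. A minimal sketch (the dict identifier and
container directory are illustrative)::

    import numpy as np
    import pytato as pt
    from pytools.persistent_dict import WriteOncePersistentDict
    from pytato.analysis import PytatoKeyBuilder

    x = pt.make_placeholder("x", (10, 10), np.float64)
    dag = pt.make_dict_of_named_arrays({"out": 2 * x})

    pd = WriteOncePersistentDict(
            "demo-pytato-pdict", key_builder=PytatoKeyBuilder(),
            container_dir="./pdict-demo")
    pd[dag] = 42           # hashes the full expression DAG
    assert pd[dag] == 42   # equal DAGs map to the same persistent key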
DistributedSendRefHolder From ad9aa2818d45e7e24f733edc67fdb07d225f569f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 12:25:51 -0600 Subject: [PATCH 095/178] improve test --- test/test_pytato.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index e5d8c2b5f..4befdbf7b 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1231,27 +1231,42 @@ def test_dot_visualizers(): # }}} -def test_persistent_dict(): +def test_persistent_hashing_and_persistent_dict(): from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError from pytato.analysis import PytatoKeyBuilder + import shutil + import tempfile axis_len = 5 - pd = WriteOncePersistentDict("test_persistent_dict", - key_builder=PytatoKeyBuilder(), - container_dir="./pytest-pdict") + try: + tmpdir = tempfile.mkdtemp() - for i in range(100): - rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=True) + pkb = PytatoKeyBuilder() - dag = make_random_dag(rdagc) - pd[dag] = 42 + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=True) - # Make sure key stays the same - with pytest.raises(ReadOnlyEntryError): + dag = make_random_dag(rdagc) + + # Make sure the PytatoKeyBuilder can handle 'dag' pd[dag] = 42 + # make sure the key stays the same across invocations + if i == 0: + assert pkb(dag) == "eaa8ad49c9490cb6f0b61a33c17d0c2fd10fafc6ce02705105cc9c379c91b9c8" + + # Make sure key stays the same + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + finally: + shutil.rmtree(tmpdir) + if __name__ == "__main__": if len(sys.argv) > 1: From 60d8e41452b36f0321972c5017b032530eff9850 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 13:03:26 -0600 Subject: [PATCH 096/178] add full invocation test --- test/test_pytato.py | 72 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 4befdbf7b..26af8d746 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1231,14 +1231,42 @@ def test_dot_visualizers(): # }}} -def test_persistent_hashing_and_persistent_dict(): +# {{{ Test PytatoKeyBuilder + +def run_test_with_new_python_invocation(f, *args, extra_env_vars = None) -> None: + import os + if extra_env_vars is None: + extra_env_vars = {} + + from base64 import b64encode + from pickle import dumps + from subprocess import check_call + + env_vars = { + "INVOCATION_INFO": b64encode(dumps((f, args))).decode(), + } + env_vars.update(extra_env_vars) + + my_env = os.environ.copy() + my_env.update(env_vars) + + check_call([sys.executable, __file__], env=my_env) + + +def run_test_with_new_python_invocation_inner() -> None: + from base64 import b64decode + from pickle import loads + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) + + f(*args) + + +def test_persistent_hashing_and_persistent_dict() -> None: from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError from pytato.analysis import PytatoKeyBuilder import shutil import tempfile - axis_len = 5 - try: tmpdir = tempfile.mkdtemp() @@ -1250,26 +1278,50 @@ def test_persistent_hashing_and_persistent_dict(): for i in range(100): rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=True) 
+ axis_len=5, use_numpy=True) dag = make_random_dag(rdagc) # Make sure the PytatoKeyBuilder can handle 'dag' pd[dag] = 42 - # make sure the key stays the same across invocations - if i == 0: - assert pkb(dag) == "eaa8ad49c9490cb6f0b61a33c17d0c2fd10fafc6ce02705105cc9c379c91b9c8" - - # Make sure key stays the same + # Make sure that the key stays the same within the same Python invocation with pytest.raises(ReadOnlyEntryError): pd[dag] = 42 + + # Make sure that the key stays the same across Python invocations + run_test_with_new_python_invocation(_test_persistent_hashing_and_persistent_dict_stage2, + tmpdir) finally: shutil.rmtree(tmpdir) +def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None: + from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError + + from pytato.analysis import PytatoKeyBuilder + pkb = PytatoKeyBuilder() + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=5, use_numpy=True) + + dag = make_random_dag(rdagc) + + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + +# }}} + if __name__ == "__main__": - if len(sys.argv) > 1: + import os + if "INVOCATION_INFO" in os.environ: + run_test_with_new_python_invocation_inner() + elif len(sys.argv) > 1: exec(sys.argv[1]) else: from pytest import main From 9d45e6577675c99cec4a0a836f2bbf49e7ad274e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 13:10:22 -0600 Subject: [PATCH 097/178] lint fixes --- test/test_pytato.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 26af8d746..cb4438eea 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1233,7 +1233,7 @@ def test_dot_visualizers(): # {{{ Test PytatoKeyBuilder -def run_test_with_new_python_invocation(f, *args, extra_env_vars = None) -> None: +def run_test_with_new_python_invocation(f, *args, extra_env_vars=None) -> None: import os if extra_env_vars is None: extra_env_vars = {} @@ -1256,6 +1256,8 @@ def run_test_with_new_python_invocation(f, *args, extra_env_vars = None) -> None def run_test_with_new_python_invocation_inner() -> None: from base64 import b64decode from pickle import loads + import os + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) f(*args) @@ -1290,11 +1292,12 @@ def test_persistent_hashing_and_persistent_dict() -> None: pd[dag] = 42 # Make sure that the key stays the same across Python invocations - run_test_with_new_python_invocation(_test_persistent_hashing_and_persistent_dict_stage2, - tmpdir) + run_test_with_new_python_invocation( + _test_persistent_hashing_and_persistent_dict_stage2, tmpdir) finally: shutil.rmtree(tmpdir) + def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None: from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError From 08be3806b4f2b153ae4adb083fd38ff63c773bde Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 16:30:43 -0600 Subject: [PATCH 098/178] add missing pymbolic expressions --- pytato/analysis/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 41a5842b0..ffaba4a8d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -485,15 +485,27 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: def update_for_function(self, key_hash: 
Any, key: Any) -> None: self.rec(key_hash, key.__module__ + key.__qualname__) - update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 # }}} From 058f6f9284afc63ea5fb910b570c39588c572a0b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 22:28:59 -0600 Subject: [PATCH 099/178] flake8 --- pytato/analysis/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index ffaba4a8d..3f9446dfe 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -490,10 +490,10 @@ def update_for_function(self, key_hash: Any, key: Any) -> None: update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: 
N815 - update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 From 0360e211272675628acfb67726bbbcb4ab52d4b7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 13 Jun 2024 14:43:04 -0500 Subject: [PATCH 100/178] remove update_for_function (now handled directly by pytools) --- pytato/analysis/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3f9446dfe..22ab396b7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -482,9 +482,6 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # CL Array self.rec(key_hash, key.get()) - def update_for_function(self, key_hash: Any, key: Any) -> None: - self.rec(key_hash, key.__module__ + key.__qualname__) - update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 From 3364a4fda9cc11d46c2ebe8e32fee0b161da27cf Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:19:06 -0500 Subject: [PATCH 101/178] working pass 1 --- pytato/analysis/__init__.py | 50 +++++++++---------- pytato/distributed/partition.py | 86 ++++++++++++++++++--------------- pytato/transform/__init__.py | 5 +- 3 files changed, 74 insertions(+), 67 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 38ed276fe..1a4359e4b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -310,48 +310,48 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter +from orderedsets import FrozenOrderedSet + class DirectPredecessorsGetter(Mapper): """ Mapper to get the `direct predecessors `__ of a node. - .. note:: - We only consider the predecessors of a nodes in a data-flow sense. 
""" - def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]: - return frozenset({dim for dim in shape if isinstance(dim, Array)}) + def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[Array]: + return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)]) - def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]: - return (frozenset(expr.bindings.values()) + def map_index_lambda(self, expr: IndexLambda) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> frozenset[Array]: - return (frozenset(expr.arrays) + def map_stack(self, expr: Stack) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> frozenset[Array]: - return (frozenset(expr.arrays) + def map_concatenate(self, expr: Concatenate) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> frozenset[Array]: - return (frozenset(expr.args) + def map_einsum(self, expr: Einsum) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]: - from pytato.loopy import LoopyCall, LoopyCallResult + def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: + from pytato.loopy import LoopyCallResult, LoopyCall assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (frozenset(ary + return (FrozenOrderedSet(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: - return (frozenset([expr.array]) - | frozenset(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> abc_Set[Array]: + return (FrozenOrderedSet([expr.array]) + | FrozenOrderedSet(idx for idx in expr.indices if isinstance(idx, Array)) | self._get_preds_from_shape(expr.shape)) @@ -360,29 +360,29 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> frozenset[Array]: - return frozenset([expr.array]) + ) -> abc_Set[Array]: + return FrozenOrderedSet([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]: + def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[Array]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]: + def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[Array]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[Array]: - return frozenset([expr.passthrough_data]) + ) -> abc_Set[Array]: + return FrozenOrderedSet([expr.passthrough_data]) - def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]: + def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[Array]: raise NotImplementedError( "DirectPredecessorsGetter does not yet support expressions containing " 
"functions.") diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 5865ec491..7c9b510e7 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -476,9 +476,10 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: frozenset[CommunicationOpIdentifier] - ) -> frozenset[CommunicationOpIdentifier]: - return reduce(frozenset.union, args, frozenset()) + self, *args: Tuple[CommunicationOpIdentifier] + ) -> Tuple[CommunicationOpIdentifier]: + from pytools import unique + return reduce(lambda x, y: tuple(unique(x+y)), args, tuple()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder @@ -496,8 +497,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: - return frozenset() + def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: + return tuple() map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -505,21 +506,21 @@ def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> frozenset[CommunicationOpIdentifier]: + ) -> Tuple[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset() + self.local_comm_ids_to_needed_comm_ids[recv_id] = tuple() self.local_recv_id_to_recv_node[recv_id] = expr - return frozenset({recv_id}) + return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> frozenset[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> Tuple[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -557,10 +558,10 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[set[TaskType]] = [set() for _ in range(nlevels)] + task_batches: Sequence[List[TaskType]] = [list() for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - task_batches[dep_level].add(task_id) + task_batches[dep_level].append(task_id) return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -623,7 +624,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: _OrderedSet[Array] = _OrderedSet() + self.materialized_arrays: List[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -633,15 +634,15 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.add(expr) + self.materialized_arrays.append(expr) if isinstance(expr, LoopyCallResult): - self.materialized_arrays.add(expr) + self.materialized_arrays.append(expr) from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): - self.materialized_arrays.add(subexpr) + self.materialized_arrays.append(subexpr) else: assert isinstance(subexpr, SCALAR_CLASSES) @@ 
-651,13 +652,13 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, frozenset[_ValueT]], - dict_b: Mapping[_KeyT, frozenset[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, frozenset[_ValueT]]: + dict_a: Mapping[_KeyT, Sequence[_ValueT]], + dict_b: Mapping[_KeyT, Sequence[_ValueT]], + mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, frozenset()) | values + result[key] = result.get(key, tuple()) + values return result # }}} @@ -782,6 +783,8 @@ def find_distributed_partition( - Gather sent arrays into assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. """ + from pytools import unique + import mpi4py.MPI as MPI from pytato.transform import SubsetDependencyMapper @@ -833,12 +836,13 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: list[_PartCommIDs] = [] + + part_comm_ids: List[_PartCommIDs] = [] if comm_batches: - recv_ids: frozenset[CommunicationOpIdentifier] = frozenset() + recv_ids: Tuple[CommunicationOpIdentifier] = tuple() for batch in comm_batches: - send_ids = frozenset( - comm_id for comm_id in batch + send_ids = tuple( + comm_id for comm_id in unique(batch) if comm_id.src_rank == local_rank) if recv_ids or send_ids: part_comm_ids.append( @@ -846,19 +850,19 @@ def find_distributed_partition( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = frozenset( - comm_id for comm_id in batch + recv_ids = tuple( + comm_id for comm_id in unique(batch) if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=frozenset())) + send_ids=tuple())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=frozenset(), - send_ids=frozenset())) + recv_ids=tuple(), + send_ids=tuple())) nparts = len(part_comm_ids) @@ -876,7 +880,7 @@ def find_distributed_partition( comm_id_to_part_id = { comm_id: ipart for ipart, comm_ids in enumerate(part_comm_ids) - for comm_id in comm_ids.send_ids | comm_ids.recv_ids} + for comm_id in unique(comm_ids.send_ids + comm_ids.recv_ids)} # }}} @@ -888,10 +892,10 @@ def find_distributed_partition( # The sets of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = _OrderedSet( + sent_arrays = tuple( send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = _OrderedSet(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = tuple(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -899,14 +903,16 @@ def find_distributed_partition( # We could allow sent *arrays* to be included here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. 
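# [Editorial note] The switch away from sets in this patch is about
# determinism: iteration order of a Python set can vary between runs, while
# first-occurrence deduplication preserves encounter order. A minimal sketch
# of the idea, assuming pytools.unique is equivalent to dict-based dedup
# (its exact implementation is not shown in this patch; `_unique_sketch` is
# an illustrative name):

def _unique_sketch(seq):
    # dict keys preserve insertion order (Python 3.7+), so this keeps the
    # first occurrence of each element, in the order encountered.
    return tuple(dict.fromkeys(seq))

assert _unique_sketch((3, 1, 3, 2, 1)) == (3, 1, 2)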
- materialized_arrays = ( - materialized_arrays_collector.materialized_arrays - - received_arrays - - sent_arrays) + materialized_arrays_set = set(materialized_arrays_collector.materialized_arrays) \ + - set(received_arrays) \ + - set(sent_arrays) + + from pytools import unique + materialized_arrays = tuple(a for a in materialized_arrays_collector.materialized_arrays if a in materialized_arrays_set) # "mso" for "materialized/sent/output" - output_arrays = _OrderedSet(outputs._data.values()) - mso_arrays = materialized_arrays | sent_arrays | output_arrays + output_arrays = tuple(outputs._data.values()) + mso_arrays = materialized_arrays + sent_arrays + output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. @@ -970,7 +976,7 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = _OrderedSet(stored_ary_to_part_id) + stored_arrays = tuple(unique(stored_ary_to_part_id)) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) @@ -986,13 +992,13 @@ def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: materialized_preds |= get_materialized_predecessors(pred) return materialized_preds - stored_arrays_promoted_to_part_outputs = { + stored_arrays_promoted_to_part_outputs = tuple(unique( stored_pred for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - } + )) # }}} diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b78c24301..56b2a53d6 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,9 +926,10 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce - return reduce(lambda acc, arg: acc | (arg & self.universe), + from pytools import unique + return reduce(lambda acc, arg: unique(tuple(acc) + tuple(set(arg) & self.universe)), args, - frozenset()) + tuple()) # }}} From ef7ea0bb74e1543b317558017f03f7be6352e5c1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:33:06 -0500 Subject: [PATCH 102/178] cleanups --- pytato/analysis/__init__.py | 6 ++++- pytato/distributed/partition.py | 43 ++++++++++++++++++--------------- pytato/transform/__init__.py | 6 +++-- setup.py | 1 + 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 1a4359e4b..d7a8e3353 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -310,14 +310,18 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter +from collections.abc import Set as abc_Set + from orderedsets import FrozenOrderedSet + class DirectPredecessorsGetter(Mapper): """ Mapper to get the `direct predecessors `__ of a node. + .. note:: We only consider the predecessors of a nodes in a data-flow sense. 
""" @@ -341,7 +345,7 @@ def map_einsum(self, expr: Einsum) -> abc_Set[Array]: | self._get_preds_from_shape(expr.shape)) def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: - from pytato.loopy import LoopyCallResult, LoopyCall + from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) return (FrozenOrderedSet(ary diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 7c9b510e7..2d4b1c93a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -476,10 +476,10 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: Tuple[CommunicationOpIdentifier] - ) -> Tuple[CommunicationOpIdentifier]: + self, *args: tuple[CommunicationOpIdentifier] + ) -> tuple[CommunicationOpIdentifier]: from pytools import unique - return reduce(lambda x, y: tuple(unique(x+y)), args, tuple()) + return reduce(lambda x, y: tuple(unique(x+y)), args, ()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder @@ -497,8 +497,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: - return tuple() + def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: + return () map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -506,21 +506,21 @@ def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> Tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = tuple() + self.local_comm_ids_to_needed_comm_ids[recv_id] = () self.local_recv_id_to_recv_node[recv_id] = expr return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> Tuple[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -558,7 +558,7 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[List[TaskType]] = [list() for _ in range(nlevels)] + task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): task_batches[dep_level].append(task_id) @@ -624,7 +624,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: List[Array] = [] + self.materialized_arrays: list[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -658,7 +658,7 @@ def _set_dict_union_mpi( assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, tuple()) + values + result[key] = result.get(key, ()) + values return result # }}} @@ -783,10 +783,10 @@ def find_distributed_partition( - Gather sent arrays into assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. 
""" - from pytools import unique - import mpi4py.MPI as MPI + from pytools import unique + from pytato.transform import SubsetDependencyMapper local_rank = mpi_communicator.rank @@ -837,9 +837,9 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: List[_PartCommIDs] = [] + part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: Tuple[CommunicationOpIdentifier] = tuple() + recv_ids: tuple[CommunicationOpIdentifier] = () for batch in comm_batches: send_ids = tuple( comm_id for comm_id in unique(batch) @@ -857,12 +857,12 @@ def find_distributed_partition( part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=tuple())) + send_ids=())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=tuple(), - send_ids=tuple())) + recv_ids=(), + send_ids=())) nparts = len(part_comm_ids) @@ -908,7 +908,9 @@ def find_distributed_partition( - set(sent_arrays) from pytools import unique - materialized_arrays = tuple(a for a in materialized_arrays_collector.materialized_arrays if a in materialized_arrays_set) + materialized_arrays = tuple( + a for a in materialized_arrays_collector.materialized_arrays + if a in materialized_arrays_set) # "mso" for "materialized/sent/output" output_arrays = tuple(outputs._data.values()) @@ -927,7 +929,8 @@ def find_distributed_partition( comm_id_to_part_id[send_id]) if __debug__: - recvd_array_dep_mapper = SubsetDependencyMapper(frozenset(received_arrays)) + recvd_array_dep_mapper = SubsetDependencyMapper(frozenset + (received_arrays)) mso_ary_to_last_dep_recv_part_id: dict[Array, int] = { ary: max( diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 56b2a53d6..642f52839 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,10 +926,12 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce + from pytools import unique - return reduce(lambda acc, arg: unique(tuple(acc) + tuple(set(arg) & self.universe)), + return reduce(lambda acc, arg: + unique(tuple(acc) + tuple(set(arg) & self.universe)), args, - tuple()) + ()) # }}} diff --git a/setup.py b/setup.py index ba0bd1b4d..9fe0df6b1 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ "immutabledict", "attrs", "bidict", + "orderedsets", ], package_data={"pytato": ["py.typed"]}, author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", From 817b255ca54f3232bc01fbd9c11ece34b54b7cac Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:44:33 -0500 Subject: [PATCH 103/178] enable determinism test --- test/test_distributed.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index ac7ca1389..1554a024b 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -899,13 +899,11 @@ def test_number_symbolic_tags_bare_classes(ctx_factory): outputs = pt.make_dict_of_named_arrays({"out": res}) partition = pt.find_distributed_partition(comm, outputs) - (_distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) + (distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) assert next_tag == 4244 - # FIXME: For the next assertion, find_distributed_partition needs to be - # deterministic too (https://github.com/inducer/pytato/pull/465). 
- # assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501 + assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # }}} From f3f3c7df968088a05a2792e189d43a70c14e34c5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:54:43 -0500 Subject: [PATCH 104/178] eliminate _OrderedSets --- pytato/distributed/partition.py | 70 +++------------------------------ 1 file changed, 6 insertions(+), 64 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 2d4b1c93a..6d3adb319 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -62,7 +62,6 @@ THE SOFTWARE. """ -import collections from functools import reduce from typing import ( TYPE_CHECKING, @@ -70,8 +69,6 @@ Any, FrozenSet, Hashable, - Iterable, - Iterator, Mapping, Sequence, TypeVar, @@ -131,61 +128,6 @@ class CommunicationOpIdentifier: _ValueT = TypeVar("_ValueT") -# {{{ crude ordered set - - -class _OrderedSet(collections.abc.MutableSet[_ValueT]): - def __init__(self, items: Iterable[_ValueT] | None = None): - # Could probably also use a valueless dictionary; not sure if it matters - self._items: set[_ValueT] = set() - self._items_ordered: list[_ValueT] = [] - if items is not None: - for item in items: - self.add(item) - - def add(self, item: _ValueT) -> None: - if item not in self._items: - self._items.add(item) - self._items_ordered.append(item) - - def discard(self, item: _ValueT) -> None: - # Not currently needed - raise NotImplementedError - - def __len__(self) -> int: - return len(self._items) - - def __iter__(self) -> Iterator[_ValueT]: - return iter(self._items_ordered) - - def __contains__(self, item: Any) -> bool: - return item in self._items - - def __and__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet() - for item in self._items_ordered: - if item in other: - result.add(item) - return result - - # Must be "Any" instead of "_ValueT", otherwise it violates Liskov substitution - # according to mypy. 
*shrug* - def __or__(self, other: AbstractSet[Any]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet(self._items_ordered) - for item in other: - result.add(item) - return result - - def __sub__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet() - for item in self._items_ordered: - if item not in other: - result.add(item) - return result - -# }}} - - # {{{ distributed graph part PartId = Hashable @@ -836,7 +778,6 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: list[_PartCommIDs] = [] if comm_batches: recv_ids: tuple[CommunicationOpIdentifier] = () @@ -986,14 +927,15 @@ def find_distributed_partition( direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: - materialized_preds: _OrderedSet[Array] = _OrderedSet() + def get_materialized_predecessors(ary: Array) -> tuple[Array]: + materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): if pred in materialized_arrays: - materialized_preds.add(pred) + materialized_preds[pred] = None else: - materialized_preds |= get_materialized_predecessors(pred) - return materialized_preds + for p in get_materialized_predecessors(pred): + materialized_preds[p] = None + return tuple(materialized_preds.keys()) stored_arrays_promoted_to_part_outputs = tuple(unique( stored_pred From 8bf2daf69260673b03ed8ef6fd2c03d01483eb04 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:29:07 -0500 Subject: [PATCH 105/178] misc improvements --- pytato/distributed/partition.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 6d3adb319..81a493c9a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -503,7 +503,8 @@ def _schedule_task_batches_counted( task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - task_batches[dep_level].append(task_id) + if task_id not in task_batches[dep_level]: + task_batches[dep_level].append(task_id) return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -566,7 +567,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: list[Array] = [] + self.materialized_arrays: dict[Array, None] = {} def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -576,15 +577,15 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.append(expr) + self.materialized_arrays[expr] = None if isinstance(expr, LoopyCallResult): - self.materialized_arrays.append(expr) + self.materialized_arrays[expr] = None from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): - self.materialized_arrays.append(subexpr) + self.materialized_arrays[subexpr] = None else: assert isinstance(subexpr, SCALAR_CLASSES) @@ -596,11 +597,12 @@ def post_visit(self, expr: Any) -> None: def _set_dict_union_mpi( dict_a: Mapping[_KeyT, Sequence[_ValueT]], dict_b: Mapping[_KeyT, Sequence[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, Sequence[_ValueT]]: + mpi_data_type: mpi4py.MPI.Datatype | None) -> 
Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None + from pytools import unique result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, ()) + values + result[key] = tuple(unique(result.get(key, ()) + values)) return result # }}} @@ -833,10 +835,10 @@ def find_distributed_partition( # The sets of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = tuple( - send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) + sent_arrays = tuple(unique( + send_node.data for send_node in lsrdg.local_send_id_to_send_node.values())) - received_arrays = tuple(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = tuple(unique(lsrdg.local_recv_id_to_recv_node.values())) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -849,13 +851,13 @@ def find_distributed_partition( - set(sent_arrays) from pytools import unique - materialized_arrays = tuple( + materialized_arrays = tuple(unique( a for a in materialized_arrays_collector.materialized_arrays - if a in materialized_arrays_set) + if a in materialized_arrays_set)) # "mso" for "materialized/sent/output" - output_arrays = tuple(outputs._data.values()) - mso_arrays = materialized_arrays + sent_arrays + output_arrays + output_arrays = tuple(unique(outputs._data.values())) + mso_arrays = tuple(unique(materialized_arrays + sent_arrays + output_arrays)) # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. @@ -870,8 +872,7 @@ def find_distributed_partition( comm_id_to_part_id[send_id]) if __debug__: - recvd_array_dep_mapper = SubsetDependencyMapper(frozenset - (received_arrays)) + recvd_array_dep_mapper = SubsetDependencyMapper(frozenset(received_arrays)) mso_ary_to_last_dep_recv_part_id: dict[Array, int] = { ary: max( From 5d906b5edb98425f48bcc3aada14725f0c9223b1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:41:48 -0500 Subject: [PATCH 106/178] revert change to SubsetDependencyMapper --- pytato/transform/__init__.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 642f52839..b78c24301 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,12 +926,9 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce - - from pytools import unique - return reduce(lambda acc, arg: - unique(tuple(acc) + tuple(set(arg) & self.universe)), + return reduce(lambda acc, arg: acc | (arg & self.universe), args, - ()) + frozenset()) # }}} From 142c8e63cf4ca9d5c570fb4e99a3d86594770bc8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:50:03 -0500 Subject: [PATCH 107/178] some mypy fixes --- pytato/distributed/partition.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 81a493c9a..e8f4b1fb2 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -67,7 +67,6 @@ TYPE_CHECKING, AbstractSet, Any, - FrozenSet, Hashable, Mapping, Sequence, @@ -316,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a 
*batch*, begins with receives and ends with sends. """ - recv_ids: frozenset[CommunicationOpIdentifier] - send_ids: frozenset[CommunicationOpIdentifier] + recv_ids: tuple[CommunicationOpIdentifier] + send_ids: tuple[CommunicationOpIdentifier] # {{{ _make_distributed_partition @@ -403,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[FrozenSet[CommunicationOpIdentifier]]): + CombineMapper[tuple[CommunicationOpIdentifier]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - frozenset[CommunicationOpIdentifier]] = {} + tuple[CommunicationOpIdentifier]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -425,7 +424,7 @@ def combine( def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -476,7 +475,7 @@ def map_named_call_result( def _schedule_task_batches( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> Sequence[AbstractSet[TaskType]]: + -> Sequence[list[TaskType]]: """For each :type:`TaskType`, determine the 'round'/'batch' during which it will be performed. A 'batch' of tasks consists of tasks which do not depend on each other. @@ -491,7 +490,7 @@ def _schedule_task_batches( def _schedule_task_batches_counted( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> tuple[Sequence[AbstractSet[TaskType]], int]: + -> tuple[Sequence[list[TaskType]], int]: """ Static type checkers need the functions to return the same type regardless of the input. 
The testing code needs to know about the number of tasks visited @@ -773,7 +772,7 @@ def find_distributed_partition( raise comm_batches_or_exc comm_batches = cast( - Sequence[AbstractSet[CommunicationOpIdentifier]], + Sequence[list[CommunicationOpIdentifier]], comm_batches_or_exc) # }}} @@ -928,7 +927,7 @@ def find_distributed_partition( direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> tuple[Array]: + def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): if pred in materialized_arrays: From bd7062071e283e7b063da74ca3f3d15c51986854 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 12 Aug 2024 16:12:04 -0700 Subject: [PATCH 108/178] ruff --- pytato/reductions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index 999ef1af4..b0f2b7fb2 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -178,9 +178,9 @@ def _normalize_reduction_axes( raise ValueError(f"{axis} is out of bounds for array of dimension" f" {len(shape)}.") - new_shape = tuple([axis_len + new_shape = tuple(axis_len for i, axis_len in enumerate(shape) - if i not in reduction_axes]) + if i not in reduction_axes) return new_shape, reduction_axes From 076a76ebe152f8d82c47de31a05f25586b981e4f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 13 Aug 2024 13:51:12 -0500 Subject: [PATCH 109/178] replace orderedsets with unique tuples in DirectPredecessorsGetter --- pytato/analysis/__init__.py | 70 +++++++++++++++++-------------------- setup.py | 1 - 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fa8ac31e7..883030a43 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method +from pytools import memoize_method, unique from pytato.array import ( Array, @@ -314,11 +314,6 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -from collections.abc import Set as abc_Set - -from orderedsets import FrozenOrderedSet - - class DirectPredecessorsGetter(Mapper): """ Mapper to get the @@ -327,74 +322,75 @@ class DirectPredecessorsGetter(Mapper): of a node. .. note:: + We only consider the predecessors of a nodes in a data-flow sense. 
""" - def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)]) + def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames]: + return tuple(unique(dim for dim in shape if isinstance(dim, Array))) - def map_index_lambda(self, expr: IndexLambda) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.bindings.values()) - | self._get_preds_from_shape(expr.shape)) + def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.bindings.values()) + + self._get_preds_from_shape(expr.shape))) - def map_stack(self, expr: Stack) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.arrays) - | self._get_preds_from_shape(expr.shape)) + def map_stack(self, expr: Stack) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.arrays) + + self._get_preds_from_shape(expr.shape))) - def map_concatenate(self, expr: Concatenate) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.arrays) - | self._get_preds_from_shape(expr.shape)) + map_concatenate = map_stack - def map_einsum(self, expr: Einsum) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.args) - | self._get_preds_from_shape(expr.shape)) + def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.args) + + self._get_preds_from_shape(expr.shape))) - def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: + def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (FrozenOrderedSet(ary + return tuple(unique(tuple(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) - | self._get_preds_from_shape(expr.shape)) + + self._get_preds_from_shape(expr.shape))) - def _map_index_base(self, expr: IndexBase) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet([expr.array]) - | FrozenOrderedSet(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames]: + return tuple(unique((expr.array,) # noqa: RUF005 + + tuple(idx for idx in expr.indices if isinstance(idx, Array)) - | self._get_preds_from_shape(expr.shape)) + + self._get_preds_from_shape(expr.shape))) map_basic_index = _map_index_base map_contiguous_advanced_index = _map_index_base map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([expr.array]) + ) -> tuple[ArrayOrNames]: + return (expr.array,) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[ArrayOrNames]: + def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[ArrayOrNames]: + def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([expr.passthrough_data]) + ) -> tuple[ArrayOrNames]: + return (expr.passthrough_data,) + + def map_call(self, expr: 
Call) -> tuple[ArrayOrNames]: + return tuple(unique(expr.bindings.values())) - def map_call(self, expr: Call) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet(expr.bindings.values()) + def map_named_call_result( + self, expr: NamedCallResult) -> tuple[ArrayOrNames]: + return (expr._container,) - def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([expr._container]) # }}} diff --git a/setup.py b/setup.py index 9fe0df6b1..ba0bd1b4d 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,6 @@ "immutabledict", "attrs", "bidict", - "orderedsets", ], package_data={"pytato": ["py.typed"]}, author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", From ea1462c0bd71ea8bf2a6d80e26e83fdbc255fc57 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Aug 2024 11:37:53 -0500 Subject: [PATCH 110/178] mypy fixes --- pytato/analysis/__init__.py | 18 +++++++++--------- pytato/distributed/partition.py | 26 +++++++++++++------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 883030a43..c568c8f9c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -325,24 +325,24 @@ class DirectPredecessorsGetter(Mapper): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames]: + def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames, ...]: return tuple(unique(dim for dim in shape if isinstance(dim, Array))) - def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames]: + def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.bindings.values()) + self._get_preds_from_shape(expr.shape))) - def map_stack(self, expr: Stack) -> tuple[ArrayOrNames]: + def map_stack(self, expr: Stack) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.arrays) + self._get_preds_from_shape(expr.shape))) map_concatenate = map_stack - def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames]: + def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.args) + self._get_preds_from_shape(expr.shape))) - def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: + def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames, ...]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) @@ -351,7 +351,7 @@ def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: if isinstance(ary, Array)) + self._get_preds_from_shape(expr.shape))) - def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames]: + def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames, ...]: return tuple(unique((expr.array,) # noqa: RUF005 + tuple(idx for idx in expr.indices if isinstance(idx, Array)) @@ -369,14 +369,14 @@ def _map_index_remapping_base(self, expr: IndexRemappingBase map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames]: + def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames, ...]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames]: + def map_distributed_recv(self, 
expr: DistributedRecv) -> tuple[ArrayOrNames, ...]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, @@ -384,7 +384,7 @@ def map_distributed_send_ref_holder(self, ) -> tuple[ArrayOrNames]: return (expr.passthrough_data,) - def map_call(self, expr: Call) -> tuple[ArrayOrNames]: + def map_call(self, expr: Call) -> tuple[ArrayOrNames, ...]: return tuple(unique(expr.bindings.values())) def map_named_call_result( diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index e8f4b1fb2..68a924c8a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -315,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: tuple[CommunicationOpIdentifier] - send_ids: tuple[CommunicationOpIdentifier] + recv_ids: tuple[CommunicationOpIdentifier, ...] + send_ids: tuple[CommunicationOpIdentifier, ...] # {{{ _make_distributed_partition @@ -402,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[tuple[CommunicationOpIdentifier]]): + CombineMapper[tuple[CommunicationOpIdentifier, ...]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - tuple[CommunicationOpIdentifier]] = {} + tuple[CommunicationOpIdentifier, ...]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -417,14 +417,14 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: tuple[CommunicationOpIdentifier] - ) -> tuple[CommunicationOpIdentifier]: + self, *args: tuple[CommunicationOpIdentifier, ...] 
+ ) -> tuple[CommunicationOpIdentifier, ...]: from pytools import unique return reduce(lambda x, y: tuple(unique(x+y)), args, ()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier, ...]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -438,7 +438,7 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: + def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: return () map_placeholder = _map_input_base @@ -447,7 +447,7 @@ def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier, ...]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: @@ -461,7 +461,7 @@ def map_distributed_recv( return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier, ...]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -594,8 +594,8 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, Sequence[_ValueT]], - dict_b: Mapping[_KeyT, Sequence[_ValueT]], + dict_a: Mapping[_KeyT, tuple[_ValueT, ...]], + dict_b: Mapping[_KeyT, tuple[_ValueT, ...]], mpi_data_type: mpi4py.MPI.Datatype | None) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None from pytools import unique @@ -781,7 +781,7 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: tuple[CommunicationOpIdentifier] = () + recv_ids: tuple[CommunicationOpIdentifier, ...] 
= ()
         for batch in comm_batches:
             send_ids = tuple(
                 comm_id for comm_id in unique(batch)

From 168ef532057be4e81649acb9353510a28ed62d84 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Wed, 14 Aug 2024 11:48:48 -0500
Subject: [PATCH 111/178] remove unnecessary cast

---
 pytato/distributed/partition.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index 68a924c8a..38f26c8d5 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -771,9 +771,7 @@ def find_distributed_partition(
     if isinstance(comm_batches_or_exc, Exception):
         raise comm_batches_or_exc
 
-    comm_batches = cast(
-        Sequence[list[CommunicationOpIdentifier]],
-        comm_batches_or_exc)
+    comm_batches = comm_batches_or_exc
 
     # }}}
 

From d711989c8cc0198cafdbbd94d294f384667b7c25 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Wed, 14 Aug 2024 14:43:43 -0500
Subject: [PATCH 112/178] adjust comment

---
 pytato/distributed/partition.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index 38f26c8d5..dea81e925 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -829,7 +829,7 @@ def find_distributed_partition(
     materialized_arrays_collector = _MaterializedArrayCollector()
     materialized_arrays_collector(outputs)
 
-    # The sets of arrays below must have a deterministic order in order to ensure
+    # The collections of arrays below must have a deterministic order in order to ensure
     # that the resulting partition is also deterministic
 
     sent_arrays = tuple(unique(

From 679f5cdb64e7166ee2acd609fe8bd6a53274aee8 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Tue, 20 Aug 2024 12:40:19 -0500
Subject: [PATCH 113/178] Revert "Implement numpy 2 type promotion"

This reverts commit f0deaea1130fe0dd4dfa52bee540910305de9c70.
---
 README.rst            |   7 +-
 pytato/array.py       | 153 +++++++++---------------------
 pytato/scalar_expr.py |   1 -
 pytato/utils.py       |  67 ++++++++----------
 test/test_codegen.py  |  16 +----
 test/test_pytato.py   |   4 +-
 6 files changed, 67 insertions(+), 181 deletions(-)

diff --git a/README.rst b/README.rst
index 524d923de..71ebccaa8 100644
--- a/README.rst
+++ b/README.rst
@@ -32,9 +32,4 @@ Numpy compatibility
 Pytato is written to pose no particular restrictions on the version of numpy
 used for execution. To use mypy-based type checking on Pytato itself or
 packages using Pytato, numpy 1.20 or newer is required, due to the
-typing-based changes to numpy in that release. Furthermore, pytato
-now uses type promotion rules aiming to match those in
-`numpy 2 `__.
-This will not break compatibility with older numpy versions, but may
-result in differing data types between computations carried out in
-numpy and pytato.
+typing-based changes to numpy in that release. 
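The revert returns pytato to plain np.result_type-based promotion. The helper
below is a minimal sketch (not part of the patch, assuming only stock numpy)
of what the restored _truediv_result_type computes: true division of integer
operands yields a floating-point dtype, mirroring np.true_divide.

    import numpy as np

    def truediv_result_type(arg1, arg2):
        # np.result_type follows whichever promotion rules the installed
        # numpy implements (value-based in 1.x, NEP 50 in 2.x).
        dtype = np.result_type(arg1, arg2)
        # Integer / integer still produces a float, as with np.true_divide.
        if dtype.kind in "iu":
            return np.dtype(np.float64)
        return dtype

    assert truediv_result_type(np.dtype(np.int32), np.dtype(np.int64)) == np.float64
    assert truediv_result_type(np.dtype(np.float32), np.dtype(np.float32)) == np.float32

With the hand-rolled NEP 50 emulation gone, result dtypes once again track the
numpy version in use, which is what the shortened README text above reflects.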
diff --git a/pytato/array.py b/pytato/array.py index eefe7c69a..2b0f38681 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -302,96 +302,22 @@ def normalize_shape_component( ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] IndexExpr = Union[IntegralT, "NormalizedSlice", "Array", None, EllipsisType] -PyScalarType = Union[type[bool], type[int], type[float], type[complex]] -DtypeOrPyScalarType = Union[_dtype_any, PyScalarType] +DtypeOrScalar = Union[_dtype_any, Scalar] ArrayOrScalar = Union["Array", Scalar] -# https://numpy.org/neps/nep-0050-scalar-promotion.html -class DtypeKindCategory(IntEnum): - BOOLEAN = 0 - INTEGRAL = 1 - INEXACT = 2 +# https://github.com/numpy/numpy/issues/19302 +def _np_result_type( + # actual dtype: + #*arrays_and_dtypes: Union[np.typing.ArrayLike, np.typing.DTypeLike], + # our dtype: + *arrays_and_dtypes: DtypeOrScalar, + ) -> np.dtype[Any]: + return np.result_type(*arrays_and_dtypes) -_dtype_kind_char_to_kind_cat = { - "b": DtypeKindCategory.BOOLEAN, - "i": DtypeKindCategory.INTEGRAL, - "u": DtypeKindCategory.INTEGRAL, - "f": DtypeKindCategory.INEXACT, - "c": DtypeKindCategory.INEXACT, -} - - -_py_type_to_kind_cat = { - bool: DtypeKindCategory.BOOLEAN, - int: DtypeKindCategory.INTEGRAL, - float: DtypeKindCategory.INEXACT, - complex: DtypeKindCategory.INEXACT, -} - - -_float_dtype_to_complex: dict[np.dtype[Any], np.dtype[Any]] = { - np.dtype(np.float32): np.dtype(np.complex64), - np.dtype(np.float64): np.dtype(np.complex128), -} - - -def _complexify_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]: - if dtype.kind == "c": - return dtype - elif dtype.kind == "f": - return _float_dtype_to_complex[dtype] - else: - raise ValueError("can only complexify types that are already inexact") - - -def _np_result_dtype(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]: - # For numpy 2.0, np.result_type does not implement numpy's type - # promotion behavior. Weird. Hence all this nonsense is needed. - - py_types = [dtype for dtype in dtypes if isinstance(dtype, type)] - - if not py_types: - return np.result_type(*dtypes) - - np_dtypes = [dtype for dtype in dtypes if isinstance(dtype, np.dtype)] - np_kind_cats = { - _dtype_kind_char_to_kind_cat[dtype.kind] for dtype in np_dtypes} - py_kind_cats = {_py_type_to_kind_cat[tp] for tp in py_types} - kind_cats = np_kind_cats | py_kind_cats - - res_kind_cat = max(kind_cats) - max_py_kind_cats = max(py_kind_cats) - max_np_kind_cats = max(np_kind_cats) - - is_complex = (complex in py_types - or any(dtype.kind == "c" for dtype in np_dtypes)) - - if max_py_kind_cats > max_np_kind_cats: - if res_kind_cat == DtypeKindCategory.INTEGRAL: - # FIXME: Perhaps this should be int32 "on some systems, e.g. Windows" - py_promotion_dtype: np.dtype[Any] = np.dtype(np.int64) - elif res_kind_cat == DtypeKindCategory.INEXACT: - if is_complex: - py_promotion_dtype = np.dtype(np.complex128) - else: - py_promotion_dtype = np.dtype(np.float64) - else: - # bool won't ever be promoted to - raise AssertionError() - return np.result_type(*([*np_dtypes, py_promotion_dtype])) - - else: - # Just ignore the python types for promotion. 
- result = np.result_type(*np_dtypes) - if is_complex: - result = _complexify_dtype(result) - return result - - -def _truediv_result_type(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]: - dtype = _np_result_dtype(*dtypes) +def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype[Any]: + dtype = _np_result_type(arg1, arg2) # See: test_true_divide in numpy/core/tests/test_ufunc.py # pylint: disable=no-member if dtype.kind in "iu": @@ -652,16 +578,11 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array: __rmatmul__ = partialmethod(__matmul__, reverse=True) - def _binary_op( - self, - op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], - other: ArrayOrScalar, - get_result_type: Callable[ - [DtypeOrPyScalarType, DtypeOrPyScalarType], - np.dtype[Any]] = _np_result_dtype, - reverse: bool = False, - cast_to_result_dtype: bool = True, - ) -> Array: + def _binary_op(self, + op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], + other: ArrayOrScalar, + get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]] = _np_result_type, # noqa + reverse: bool = False) -> Array: # {{{ sanity checks @@ -675,19 +596,15 @@ def _binary_op( import pytato.utils as utils if reverse: - result = utils.broadcast_binary_op( - other, self, op, - get_result_type, - tags=tags, - non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + result = utils.broadcast_binary_op(other, self, op, + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) else: - result = utils.broadcast_binary_op( - self, other, op, - get_result_type, - tags=tags, - non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + result = utils.broadcast_binary_op(self, other, op, + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) assert isinstance(result, Array) return result @@ -1495,7 +1412,7 @@ class Stack(_SuppliedAxesAndTagsMixin, Array): @property def dtype(self) -> np.dtype[Any]: - return _np_result_dtype(*(arr.dtype for arr in self.arrays)) + return _np_result_type(*(arr.dtype for arr in self.arrays)) @property def shape(self) -> ShapeType: @@ -1528,7 +1445,7 @@ class Concatenate(_SuppliedAxesAndTagsMixin, Array): @property def dtype(self) -> np.dtype[Any]: - return _np_result_dtype(*(arr.dtype for arr in self.arrays)) + return _np_result_type(*(arr.dtype for arr in self.arrays)) @property def shape(self) -> ShapeType: @@ -2149,7 +2066,7 @@ def reshape(array: Array, newshape: int | Sequence[int], """ :param array: array to be reshaped :param newshape: shape of the resulting array - :param order: ``"C"`` or ``"F"``. Layout order of the resulting array. + :param order: ``"C"`` or ``"F"``. Layout order of the resulting array. .. note:: @@ -2491,14 +2408,12 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool: import pytato.utils as utils # type-ignored because 'broadcast_binary_op' returns Scalar, while # '_compare' returns a bool. 
- return utils.broadcast_binary_op( - x1, x2, - lambda x, y: prim.Comparison(x, which, y), - lambda x, y: np.dtype(np.bool_), - tags=_get_default_tags(), - non_equality_tags=_get_created_at_tag(stacklevel=2), - cast_to_result_dtype=False - ) # type: ignore[return-value] + return utils.broadcast_binary_op(x1, x2, + lambda x, y: prim.Comparison(x, which, y), + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=_get_created_at_tag(stacklevel=2), + ) # type: ignore[return-value] def equal(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: @@ -2560,7 +2475,6 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), - cast_to_result_dtype=False, ) # type: ignore[return-value] @@ -2577,7 +2491,6 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), - cast_to_result_dtype=False, ) # type: ignore[return-value] diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 989d2c405..606e4752a 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -82,7 +82,6 @@ Scalar = Union[np.number[Any], int, np.bool_, bool, float, complex] ScalarExpression = Union[Scalar, prim.Expression] -PYTHON_SCALAR_CLASSES = (int, float, complex, bool) SCALAR_CLASSES = prim.VALID_CONSTANT_CLASSES diff --git a/pytato/utils.py b/pytato/utils.py index a7f817df4..722a0e3b2 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -31,9 +31,21 @@ TypeVar, ) -import islpy as isl -import numpy as np +from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, + Optional, Iterable, TypeVar, FrozenSet) +from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, + DtypeOrScalar, ArrayOrScalar, BasicIndex, + AdvancedIndexInContiguousAxes, + AdvancedIndexInNoncontiguousAxes, + ConvertibleToIndexExpr, IndexExpr, NormalizedSlice, + _dtype_any, Einsum) +from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, + SCALAR_CLASSES, INT_CLASSES, BoolT) +from pytools import UniqueNameGenerator +from pytato.transform import Mapper +from pytools.tag import Tag from immutabledict import immutabledict +import numpy as np import pymbolic.primitives as prim from pytools import UniqueNameGenerator @@ -46,7 +58,6 @@ ArrayOrScalar, BasicIndex, ConvertibleToIndexExpr, - DtypeOrPyScalarType, Einsum, IndexExpr, IndexLambda, @@ -58,7 +69,6 @@ ) from pytato.scalar_expr import ( INT_CLASSES, - PYTHON_SCALAR_CLASSES, SCALAR_CLASSES, BoolT, IntegralScalarExpression, @@ -164,18 +174,15 @@ def with_indices_for_broadcasted_shape(val: prim.Variable, shape: ShapeType, return val[get_indexing_expression(shape, result_shape)] -def _extract_dtypes( - exprs: Sequence[ArrayOrScalar]) -> list[DtypeOrPyScalarType]: - dtypes: list[DtypeOrPyScalarType] = [] +def extract_dtypes_or_scalars( + exprs: Sequence[ArrayOrScalar]) -> List[DtypeOrScalar]: + dtypes: List[DtypeOrScalar] = [] for expr in exprs: if isinstance(expr, Array): dtypes.append(expr.dtype) - elif isinstance(expr, np.generic): - dtypes.append(expr.dtype) - elif isinstance(expr, PYTHON_SCALAR_CLASSES): - dtypes.append(type(expr)) else: - raise TypeError(f"unexpected expression type: '{type(expr)}'") + assert isinstance(expr, SCALAR_CLASSES) + dtypes.append(expr) return dtypes @@ -208,21 +215,24 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def 
broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 - get_result_type: Callable[[DtypeOrPyScalarType, DtypeOrPyScalarType], np.dtype[Any]], # noqa:E501 - *, - tags: frozenset[Tag], - non_equality_tags: frozenset[Tag], - cast_to_result_dtype: bool, + get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Tag], ) -> ArrayOrScalar: from pytato.array import _get_default_axes + if isinstance(a1, SCALAR_CLASSES): + a1 = np.dtype(type(a1)).type(a1) + + if isinstance(a2, SCALAR_CLASSES): + a2 = np.dtype(type(a2)).type(a2) + if np.isscalar(a1) and np.isscalar(a2): from pytato.scalar_expr import evaluate return evaluate(op(a1, a2)) # type: ignore result_shape = get_shape_after_broadcasting([a1, a2]) - - dtypes = _extract_dtypes([a1, a2]) + dtypes = extract_dtypes_or_scalars([a1, a2]) result_dtype = get_result_type(*dtypes) bindings: dict[str, Array] = {} @@ -232,25 +242,6 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, expr2 = update_bindings_and_get_broadcasted_expr(a2, "_in1", bindings, result_shape) - def cast_to_result_type( - array: ArrayOrScalar, - expr: ScalarExpression - ) -> ScalarExpression: - if ((isinstance(array, Array) or isinstance(array, np.generic)) - and array.dtype != result_dtype): - # Loopy's type casts don't like casting to bool - assert result_dtype != np.bool_ - - expr = TypeCast(result_dtype, expr) - elif isinstance(expr, SCALAR_CLASSES): - expr = result_dtype.type(expr) - - return expr - - if cast_to_result_dtype: - expr1 = cast_to_result_type(a1, expr1) - expr2 = cast_to_result_type(a2, expr2) - return IndexLambda(expr=op(expr1, expr2), shape=result_shape, dtype=result_dtype, diff --git a/test/test_codegen.py b/test/test_codegen.py index a7c2d615a..661a4092f 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -272,9 +272,6 @@ def wrapper(*args): "logical_or")) @pytest.mark.parametrize("reverse", (False, True)) def test_scalar_array_binary_arith(ctx_factory, which, reverse): - from numpy.lib import NumpyVersion - is_old_numpy = NumpyVersion(np.__version__) < "2.0.0" - cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal", @@ -319,18 +316,9 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse): out = outputs[dtype] out_ref = np_op(x_in, y_orig.astype(dtype)) - if not is_old_numpy: - assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) - + assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) # In some cases ops are done in float32 in loopy but float64 in numpy. 
- is_allclose = np.allclose(out, out_ref), (out, out_ref) - if not is_old_numpy: - assert is_allclose - else: - if out_ref.dtype.itemsize == 1: - pass - else: - assert is_allclose + assert np.allclose(out, out_ref), (out, out_ref) @pytest.mark.parametrize("which", ("add", "sub", "mul", "truediv", "pow", diff --git a/test/test_pytato.py b/test/test_pytato.py index e9c23d8fe..946983ac6 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -514,8 +514,8 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): dtype='int64', expr=Product((Subscript(Variable('_in0'), (Variable('_0'), Variable('_1'))), - TypeCast(dtype('int64'), Subscript(Variable('_in1'), - (Variable('_0'), Variable('_1')))))), + Subscript(Variable('_in1'), + (Variable('_0'), Variable('_1'))))), bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'), '_in1': IndexLambda( shape=(10, 4), From a6a91c6f6844bedd57e360254f581b6377470a82 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 26 Sep 2024 10:24:19 -0500 Subject: [PATCH 114/178] Fix a merge fail --- pytato/array.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 10e2ba727..3867b0641 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -586,11 +586,12 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array: # ======= def _binary_op( self, - op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], + op: Callable[[ScalarExpression, ScalarExpression], + ScalarExpression], other: ArrayOrScalar, get_result_type: Callable[ [ArrayOrScalar, ArrayOrScalar], - np.dtype[Any]] = _np_result_dtype, + np.dtype[Any]] = _np_result_type, reverse: bool = False, cast_to_result_dtype: bool = True, is_pow: bool = False, From 3b7385e6fa7b5806ebe169491eecda30eb5a3b72 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 26 Sep 2024 12:25:47 -0500 Subject: [PATCH 115/178] Merge with main --- pytato/array.py | 40 ---------------------------------------- pytato/utils.py | 9 --------- 2 files changed, 49 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 3867b0641..b375ebfb3 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -577,13 +577,6 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array: __rmatmul__ = partialmethod(__matmul__, reverse=True) -# <<<<<<< HEAD -# def _binary_op(self, -# op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], -# other: ArrayOrScalar, -# get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]] = _np_result_type, # noqa -# reverse: bool = False) -> Array: -# ======= def _binary_op( self, op: Callable[[ScalarExpression, ScalarExpression], @@ -596,7 +589,6 @@ def _binary_op( cast_to_result_dtype: bool = True, is_pow: bool = False, ) -> Array: -# >>>>>>> main # {{{ sanity checks @@ -610,17 +602,6 @@ def _binary_op( import pytato.utils as utils if reverse: -# <<<<<<< HEAD -# result = utils.broadcast_binary_op(other, self, op, -# get_result_type, -# tags=tags, -# non_equality_tags=non_equality_tags) -# else: -# result = utils.broadcast_binary_op(self, other, op, -# get_result_type, -# tags=tags, -# non_equality_tags=non_equality_tags) -# ======= result = utils.broadcast_binary_op( other, self, op, get_result_type, @@ -636,7 +617,6 @@ def _binary_op( non_equality_tags=non_equality_tags, cast_to_result_dtype=cast_to_result_dtype, is_pow=is_pow) -# >>>>>>> main assert isinstance(result, Array) return result @@ -688,13 +668,8 @@ def _unary_op(self, op: Any) -> Array: __rtruediv__ = 
partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type, reverse=True) -# <<<<<<< HEAD -# __pow__ = partialmethod(_binary_op, prim.Power) -# __rpow__ = partialmethod(_binary_op, prim.Power, reverse=True) -#======= __pow__ = partialmethod(_binary_op, operator.pow, is_pow=True) __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True) -# >>>>>>> main __neg__ = partialmethod(_unary_op, operator.neg) @@ -2443,14 +2418,6 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool: import pytato.utils as utils # type-ignored because 'broadcast_binary_op' returns Scalar, while # '_compare' returns a bool. -# <<<<<<< HEAD -# return utils.broadcast_binary_op(x1, x2, -# lambda x, y: prim.Comparison(x, which, y), -# lambda x, y: np.dtype(np.bool_), -# tags=_get_default_tags(), -# non_equality_tags=_get_created_at_tag(stacklevel=2), -# ) # type: ignore[return-value] -# ======= return utils.broadcast_binary_op( x1, x2, lambda x, y: prim.Comparison(x, which, y), @@ -2460,7 +2427,6 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool: cast_to_result_dtype=False, is_pow=False, ) # type: ignore[return-value] -# >>>>>>> main def equal(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: @@ -2522,11 +2488,8 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), -# <<<<<<< HEAD -# ======= cast_to_result_dtype=False, is_pow=False, -# >>>>>>> main ) # type: ignore[return-value] @@ -2543,11 +2506,8 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), -# <<<<<<< HEAD -# ======= cast_to_result_dtype=False, is_pow=False, -# >>>>>>> main ) # type: ignore[return-value] diff --git a/pytato/utils.py b/pytato/utils.py index 9b429f097..9cb4a9d92 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -215,18 +215,12 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 -# <<<<<<< HEAD -# get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 -# tags: FrozenSet[Tag], -# non_equality_tags: FrozenSet[Tag], -# ======= get_result_type: Callable[[ArrayOrScalar, ArrayOrScalar], np.dtype[Any]], # noqa:E501 *, tags: frozenset[Tag], non_equality_tags: frozenset[Tag], cast_to_result_dtype: bool, is_pow: bool, -# >>>>>>> main ) -> ArrayOrScalar: from pytato.array import _get_default_axes @@ -251,8 +245,6 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, expr2 = update_bindings_and_get_broadcasted_expr(a2, "_in1", bindings, result_shape) -# <<<<<<< HEAD -# ======= def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression @@ -282,7 +274,6 @@ def cast_to_result_type( expr1 = cast_to_result_type(a1, expr1) expr2 = cast_to_result_type(a2, expr2) -# >>>>>>> main return IndexLambda(expr=op(expr1, expr2), shape=result_shape, dtype=result_dtype, From 58478004b85604d9ac0b238ab55d601326ef3b94 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 13:01:19 -0500 Subject: [PATCH 116/178] performance fix --- pytato/distributed/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 
dea81e925..03691e6ae 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -928,7 +928,7 @@ def find_distributed_partition( def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): - if pred in materialized_arrays: + if pred in materialized_arrays_set: materialized_preds[pred] = None else: for p in get_materialized_predecessors(pred): From 7dd83bb7ab25e701f3a9b5d05477d991d0fbb1fa Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 14:22:21 -0500 Subject: [PATCH 117/178] switch to dicts --- pytato/analysis/__init__.py | 62 ++++++++++++++++---------------- pytato/distributed/partition.py | 63 +++++++++++++++------------------ 2 files changed, 61 insertions(+), 64 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c568c8f9c..e721a7a8a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method, unique +from pytools import memoize_method from pytato.array import ( Array, @@ -325,71 +325,73 @@ class DirectPredecessorsGetter(Mapper): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames, ...]: - return tuple(unique(dim for dim in shape if isinstance(dim, Array))) + def _get_preds_from_shape(self, shape: ShapeType) -> dict[Array, None]: + return dict.fromkeys(dim for dim in shape if isinstance(dim, Array)) - def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.bindings.values()) - + self._get_preds_from_shape(expr.shape))) + def map_index_lambda(self, expr: IndexLambda) -> dict[Array, None]: + return (dict.fromkeys(expr.bindings.values()) + | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.arrays) - + self._get_preds_from_shape(expr.shape))) + def map_stack(self, expr: Stack) -> dict[Array, None]: + return (dict.fromkeys(expr.arrays) + | self._get_preds_from_shape(expr.shape)) - map_concatenate = map_stack + def map_concatenate(self, expr: Concatenate) -> dict[Array, None]: + return (dict.fromkeys(expr.arrays) + | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.args) - + self._get_preds_from_shape(expr.shape))) + def map_einsum(self, expr: Einsum) -> dict[Array, None]: + return (dict.fromkeys(expr.args) + | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames, ...]: + def map_loopy_call_result(self, expr: NamedArray) -> dict[Array, None]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return tuple(unique(tuple(ary + return (dict.fromkeys(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) - + self._get_preds_from_shape(expr.shape))) + | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames, ...]: - return tuple(unique((expr.array,) # noqa: RUF005 - + tuple(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> dict[Array, None]: + return (dict.fromkeys([expr.array]) + | 
dict.fromkeys(idx for idx in expr.indices if isinstance(idx, Array)) - + self._get_preds_from_shape(expr.shape))) + | self._get_preds_from_shape(expr.shape)) map_basic_index = _map_index_base map_contiguous_advanced_index = _map_index_base map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> tuple[ArrayOrNames]: - return (expr.array,) + ) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames, ...]: + def _map_input_base(self, expr: InputArgumentBase) -> dict[Array, None]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames, ...]: + def map_distributed_recv(self, expr: DistributedRecv) -> dict[Array, None]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[ArrayOrNames]: - return (expr.passthrough_data,) + ) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr.passthrough_data]) - def map_call(self, expr: Call) -> tuple[ArrayOrNames, ...]: - return tuple(unique(expr.bindings.values())) + def map_call(self, expr: Call) -> dict[ArrayOrNames, None]: + return dict.fromkeys(expr.bindings.values()) def map_named_call_result( - self, expr: NamedCallResult) -> tuple[ArrayOrNames]: - return (expr._container,) + self, expr: NamedCallResult) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr._container]) # }}} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 03691e6ae..9e9f47913 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -315,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: tuple[CommunicationOpIdentifier, ...] - send_ids: tuple[CommunicationOpIdentifier, ...] + recv_ids: immutabledict[CommunicationOpIdentifier, None] + send_ids: immutabledict[CommunicationOpIdentifier, None] # {{{ _make_distributed_partition @@ -727,8 +727,7 @@ def find_distributed_partition( assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. """ import mpi4py.MPI as MPI - - from pytools import unique + from immutabledict import immutabledict from pytato.transform import SubsetDependencyMapper @@ -779,30 +778,31 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: tuple[CommunicationOpIdentifier, ...] 
= () + recv_ids: immutabledict[CommunicationOpIdentifier, None] = immutabledict() for batch in comm_batches: - send_ids = tuple( - comm_id for comm_id in unique(batch) - if comm_id.src_rank == local_rank) + send_ids: immutabledict[CommunicationOpIdentifier, None] \ + = immutabledict.fromkeys( + comm_id for comm_id in batch + if comm_id.src_rank == local_rank) if recv_ids or send_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = tuple( - comm_id for comm_id in unique(batch) + recv_ids = immutabledict.fromkeys( + comm_id for comm_id in batch if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=())) + send_ids=immutabledict())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=(), - send_ids=())) + recv_ids=immutabledict(), + send_ids=immutabledict())) nparts = len(part_comm_ids) @@ -820,7 +820,7 @@ def find_distributed_partition( comm_id_to_part_id = { comm_id: ipart for ipart, comm_ids in enumerate(part_comm_ids) - for comm_id in unique(comm_ids.send_ids + comm_ids.recv_ids)} + for comm_id in comm_ids.send_ids | comm_ids.recv_ids} # }}} @@ -832,10 +832,10 @@ def find_distributed_partition( # The collections of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = tuple(unique( - send_node.data for send_node in lsrdg.local_send_id_to_send_node.values())) + sent_arrays = dict.fromkeys( + send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = tuple(unique(lsrdg.local_recv_id_to_recv_node.values())) + received_arrays = dict.fromkeys(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -843,18 +843,13 @@ def find_distributed_partition( # We could allow sent *arrays* to be included here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. - materialized_arrays_set = set(materialized_arrays_collector.materialized_arrays) \ - - set(received_arrays) \ - - set(sent_arrays) - - from pytools import unique - materialized_arrays = tuple(unique( - a for a in materialized_arrays_collector.materialized_arrays - if a in materialized_arrays_set)) + materialized_arrays = {a: None + for a in materialized_arrays_collector.materialized_arrays + if a not in received_arrays | sent_arrays} # "mso" for "materialized/sent/output" - output_arrays = tuple(unique(outputs._data.values())) - mso_arrays = tuple(unique(materialized_arrays + sent_arrays + output_arrays)) + output_arrays = dict.fromkeys(outputs._data.values()) + mso_arrays = materialized_arrays | sent_arrays | output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. 
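A brief aside on the dict-based pattern this patch adopts (an illustrative
sketch, not part of the diff): since Python 3.7 dicts preserve insertion
order, and since Python 3.9 the | operator merges them while keeping that
order, so dict.fromkeys acts as an ordered, deduplicating set. That is what
makes the resulting partition deterministic, a guarantee plain set (with
unspecified iteration order) cannot provide.

    # dict.fromkeys deduplicates while keeping first-seen order;
    # merging with | then behaves like an ordered-set union.
    a = dict.fromkeys([3, 1, 2])
    b = dict.fromkeys([2, 4])
    assert list(a | b) == [3, 1, 2, 4]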
@@ -918,30 +913,30 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = tuple(unique(stored_ary_to_part_id)) + stored_arrays = dict.fromkeys(stored_ary_to_part_id) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: + def get_materialized_predecessors(ary: Array) -> dict[Array, None]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): - if pred in materialized_arrays_set: + if pred in materialized_arrays: materialized_preds[pred] = None else: for p in get_materialized_predecessors(pred): materialized_preds[p] = None - return tuple(materialized_preds.keys()) + return materialized_preds - stored_arrays_promoted_to_part_outputs = tuple(unique( - stored_pred + stored_arrays_promoted_to_part_outputs = { + stored_pred: None for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - )) + } # }}} From 1ea962cc9460b6407e453dd5bebfc28cc02830ae Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 14:40:51 -0500 Subject: [PATCH 118/178] more dict usage --- pytato/distributed/partition.py | 43 ++++++++++++++++----------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 9e9f47913..111f07d2e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -402,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[tuple[CommunicationOpIdentifier, ...]]): + CombineMapper[dict[CommunicationOpIdentifier, None]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - tuple[CommunicationOpIdentifier, ...]] = {} + dict[CommunicationOpIdentifier, None]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -417,14 +417,13 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: tuple[CommunicationOpIdentifier, ...] 
- ) -> tuple[CommunicationOpIdentifier, ...]: - from pytools import unique - return reduce(lambda x, y: tuple(unique(x+y)), args, ()) + self, *args: dict[CommunicationOpIdentifier, None] + ) -> dict[CommunicationOpIdentifier, None]: + return reduce(lambda x, y: x | y, args, {}) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[CommunicationOpIdentifier, ...]: + ) -> dict[CommunicationOpIdentifier, None]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -438,8 +437,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: - return () + def _map_input_base(self, expr: Array) -> dict[CommunicationOpIdentifier, None]: + return {} map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -447,21 +446,21 @@ def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: def map_distributed_recv( self, expr: DistributedRecv - ) -> tuple[CommunicationOpIdentifier, ...]: + ) -> dict[CommunicationOpIdentifier, None]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = () + self.local_comm_ids_to_needed_comm_ids[recv_id] = {} self.local_recv_id_to_recv_node[recv_id] = expr - return (recv_id,) + return {recv_id: None} def map_named_call_result( - self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier, ...]: + self, expr: NamedCallResult) -> dict[CommunicationOpIdentifier, None]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -475,7 +474,7 @@ def map_named_call_result( def _schedule_task_batches( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> Sequence[list[TaskType]]: + -> Sequence[dict[TaskType, None]]: """For each :type:`TaskType`, determine the 'round'/'batch' during which it will be performed. A 'batch' of tasks consists of tasks which do not depend on each other. @@ -490,7 +489,7 @@ def _schedule_task_batches( def _schedule_task_batches_counted( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> tuple[Sequence[list[TaskType]], int]: + -> tuple[Sequence[dict[TaskType, None]], int]: """ Static type checkers need the functions to return the same type regardless of the input. 
The testing code needs to know about the number of tasks visited @@ -499,11 +498,11 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] + task_batches: Sequence[dict[TaskType, None]] = [{} for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): if task_id not in task_batches[dep_level]: - task_batches[dep_level].append(task_id) + task_batches[dep_level][task_id] = None return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -594,14 +593,14 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, tuple[_ValueT, ...]], - dict_b: Mapping[_KeyT, tuple[_ValueT, ...]], - mpi_data_type: mpi4py.MPI.Datatype | None) -> Mapping[_KeyT, Sequence[_ValueT]]: + dict_a: Mapping[_KeyT, dict[_ValueT, None]], + dict_b: Mapping[_KeyT, dict[_ValueT, None]], + mpi_data_type: mpi4py.MPI.Datatype | None) \ + -> Mapping[_KeyT, dict[_ValueT, None]]: assert mpi_data_type is None - from pytools import unique result = dict(dict_a) for key, values in dict_b.items(): - result[key] = tuple(unique(result.get(key, ()) + values)) + result[key] = result.get(key, {}) | values return result # }}} From 99f4d10ae535dc93afe5b1304966a808e33ca938 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 11 Nov 2024 11:12:22 -0600 Subject: [PATCH 119/178] Import union --- pytato/scalar_expr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index bfbda9e71..bc137def8 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -50,6 +50,7 @@ TYPE_CHECKING, Any, Never, + Union, cast, ) From cede10c40d803590f326f22e890c7243ca95d487 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 11 Nov 2024 11:35:15 -0600 Subject: [PATCH 120/178] Use IntegralT --> IntegerT --- pytato/array.py | 2 +- pytato/scalar_expr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 2c197b552..a58338e12 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -297,7 +297,7 @@ def normalize_shape_component( # {{{ array interface ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] -IndexExpr = Union[IntegralT, "NormalizedSlice", "Array", None, EllipsisType] +IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None, EllipsisType] DtypeOrScalar = Union[_dtype_any, Scalar] ArrayOrScalar = Union["Array", Scalar] PyScalarType = type[bool] | type[int] | type[float] | type[complex] diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index bc137def8..58564b2fb 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -85,7 +85,7 @@ # {{{ scalar expressions INT_CLASSES = (int, np.integer) -IntegralScalarExpression = Union[IntegralT, prim.Expression] +IntegralScalarExpression = Union[IntegerT, prim.Expression] Scalar = Union[np.number[Any], int, np.bool_, bool, float, complex] ScalarExpression = Union[Scalar, prim.Expression] PYTHON_SCALAR_CLASSES = (int, float, complex, bool) From 02b1980cafbf0fa124ed43db218df467cd8a5983 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 11 Nov 2024 11:45:27 -0600 Subject: [PATCH 121/178] Use Scalar --> ScalarT --- pytato/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index a58338e12..805b3671e 100644 --- 
a/pytato/array.py +++ b/pytato/array.py @@ -298,8 +298,8 @@ def normalize_shape_component( ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None, EllipsisType] -DtypeOrScalar = Union[_dtype_any, Scalar] -ArrayOrScalar = Union["Array", Scalar] +DtypeOrScalar = Union[_dtype_any, ScalarT] +ArrayOrScalar = Union["Array", ScalarT] PyScalarType = type[bool] | type[int] | type[float] | type[complex] DtypeOrPyScalarType = _dtype_any | PyScalarType From b820049160910fa48b1a2fd2efcd757b787ed113 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 11 Nov 2024 15:26:50 -0600 Subject: [PATCH 122/178] Disable update_for_pymbolic_expression --- pytato/analysis/__init__.py | 44 ++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 8577961d0..9c156786d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -584,28 +584,28 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # CL Array self.rec(key_hash, key.get()) - update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # 
update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 # }}} From 1337590a29acb3fcc3b487734a200eda32a1ebc5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 12 Nov 2024 13:22:16 -0600 Subject: [PATCH 123/178] remove duplicate Hashable --- pytato/distributed/partition.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 43f9d9197..86fe5df51 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -69,9 +69,6 @@ TYPE_CHECKING, Any, Generic, - Hashable, - Mapping, - Sequence, TypeVar, cast, ) From 2c001d6be6b876e2471da2f2261eff5e081e753d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 12 Nov 2024 13:23:56 -0600 Subject: [PATCH 124/178] add missing import --- pytato/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/utils.py b/pytato/utils.py index a8f9a2086..e243a20ee 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -44,6 +44,7 @@ from pytools.tag import Tag from immutabledict import immutabledict import numpy as np +import islpy as isl import pymbolic.primitives as prim from pymbolic import ScalarT From b8c0caeeb4760928d92a606ba1d0905caacc3d13 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 21 Nov 2024 16:45:48 -0600 Subject: [PATCH 125/178] do not pickle cached hash of arrays --- pytato/array.py | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index b637197d9..95c5e7a59 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -336,15 +336,34 @@ def _augment_array_dataclass( cls: type, generate_hash: bool, ) -> None: - from dataclasses import fields - attr_tuple = ", ".join(f"self.{fld.name}" - for fld in fields(cls) if fld.name != "non_equality_tags") - if attr_tuple: - attr_tuple = f"({attr_tuple},)" - else: - attr_tuple = "()" + + # {{{ hashing and hash caching if generate_hash: + from dataclasses import fields + + attr_tuple = ", ".join(f"self.{fld.name}" for fld in fields(cls)) + if 
attr_tuple: + attr_tuple = f"({attr_tuple},)" + else: + attr_tuple = "()" + + fld_name_tuple = ", ".join(f"'{fld.name}'" for fld in fields(cls)) + if fld_name_tuple: + fld_name_tuple = f"({fld_name_tuple},)" + else: + fld_name_tuple = "()" + + # Non-equality tags are automatically excluded from equality in + # EqualityComparer, and are excluded here from hashing. + attr_tuple_hash = ", ".join(f"self.{fld.name}" + for fld in fields(cls) if fld.name != "non_equality_tags") + + if attr_tuple_hash: + attr_tuple_hash = f"({attr_tuple_hash},)" + else: + attr_tuple_hash = "()" + from pytools.codegen import remove_common_indentation augment_code = remove_common_indentation( f""" @@ -354,17 +373,26 @@ def {cls.__name__}_hash(self): except AttributeError: pass - h = hash(frozenset({attr_tuple})) + h = hash(frozenset({attr_tuple_hash})) object.__setattr__(self, "_hash_value", h) return h cls.__hash__ = {cls.__name__}_hash + + def {cls.__name__}_getstate(self): + # This must not return the cached hash value. + return {{name: val + for name, val in zip({fld_name_tuple}, {attr_tuple}, strict=True)}} + + cls.__getstate__ = {cls.__name__}_getstate """) exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": augment_code} exec(compile(augment_code, f"", "exec"), exec_dict) + # }}} + # {{{ assign mapper_method mm_cls = cast(type[_HasMapperMethod], cls) From ab456e8fb96ad3564e9417d70976ca33c4d93660 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 21 Nov 2024 19:22:45 -0600 Subject: [PATCH 126/178] no clue why this works --- pytato/array.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 95c5e7a59..97a82e557 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -367,6 +367,8 @@ def _augment_array_dataclass( from pytools.codegen import remove_common_indentation augment_code = remove_common_indentation( f""" + import dataclasses + def {cls.__name__}_hash(self): try: return self._hash_value @@ -379,12 +381,8 @@ def {cls.__name__}_hash(self): cls.__hash__ = {cls.__name__}_hash - def {cls.__name__}_getstate(self): - # This must not return the cached hash value. - return {{name: val - for name, val in zip({fld_name_tuple}, {attr_tuple}, strict=True)}} - - cls.__getstate__ = {cls.__name__}_getstate + cls.__getstate__ = dataclasses._dataclass_getstate + cls.__setstate__ = dataclasses._dataclass_setstate """) exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": augment_code} exec(compile(augment_code, From d5fcd6d9b3418e798ff0cfa7a90dfc32da944e01 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 21 Nov 2024 19:29:58 -0600 Subject: [PATCH 127/178] remove unneeded changes --- pytato/array.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 97a82e557..751d85507 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -342,18 +342,6 @@ def _augment_array_dataclass( if generate_hash: from dataclasses import fields - attr_tuple = ", ".join(f"self.{fld.name}" for fld in fields(cls)) - if attr_tuple: - attr_tuple = f"({attr_tuple},)" - else: - attr_tuple = "()" - - fld_name_tuple = ", ".join(f"'{fld.name}'" for fld in fields(cls)) - if fld_name_tuple: - fld_name_tuple = f"({fld_name_tuple},)" - else: - fld_name_tuple = "()" - # Non-equality tags are automatically excluded from equality in # EqualityComparer, and are excluded here from hashing. 
attr_tuple_hash = ", ".join(f"self.{fld.name}" From 85920d61882cb6104e4d39ac58f49a6554401078 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Sun, 24 Nov 2024 15:03:53 -0600 Subject: [PATCH 128/178] Update from merge errors --- pytato/array.py | 6 ++++-- pytato/scalar_expr.py | 2 +- pytato/utils.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index ae1c7f206..2e6b44f22 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -411,8 +411,10 @@ def {cls.__name__}_hash(self): ConvertibleToIndexExpr = Union[int, slice, "Array", EllipsisType, None] # IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None, EllipsisType] # IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None] -DtypeOrScalar = Union[_dtype_any, ScalarT] -ArrayOrScalar = Union["Array", ScalarT] +# DtypeOrScalar = Union[_dtype_any, ScalarT] +# ArrayOrScalar = Union["Array", ScalarT] +DtypeOrScalar = Union[_dtype_any, Scalar] +ArrayOrScalar = Union["Array", Scalar] IndexExpr = Union[Integer, "NormalizedSlice", "Array", None] PyScalarType = type[bool] | type[int] | type[float] | type[complex] DtypeOrPyScalarType = _dtype_any | PyScalarType diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 545995056..941e00c6d 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -85,7 +85,7 @@ # {{{ scalar expressions INT_CLASSES = (int, np.integer) -IntegralScalarExpression = Union[IntegerT, prim.Expression] +# IntegralScalarExpression = Union[IntegerT, prim.Expression] Scalar = Union[np.number[Any], int, np.bool_, bool, float, complex] ScalarExpression = Union[Scalar, prim.Expression] PYTHON_SCALAR_CLASSES = (int, float, complex, bool) diff --git a/pytato/utils.py b/pytato/utils.py index c65d94dcb..58d81e9b6 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -37,8 +37,10 @@ AdvancedIndexInNoncontiguousAxes, ConvertibleToIndexExpr, IndexExpr, NormalizedSlice, _dtype_any, Einsum) +#from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, +# SCALAR_CLASSES, INT_CLASSES, BoolT) from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, - SCALAR_CLASSES, INT_CLASSES, BoolT) + SCALAR_CLASSES, INT_CLASSES) from pytools import UniqueNameGenerator from pytato.transform import Mapper from pytools.tag import Tag From 4d7a8319ce0e81474b04ac35463350fd5127ea67 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 2 Dec 2024 13:33:35 -0600 Subject: [PATCH 129/178] Test this quick work-around for problem with -O --- pytato/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 439966f87..f378ccd4c 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -315,7 +315,7 @@ def array_dataclass(hash: bool = True) -> Callable[[type[T]], type[T]]: def map_cls(cls: type[T]) -> type[T]: # Frozen dataclasses (empirically) have a ~20% speed penalty, # and their frozen-ness is arguably a debug feature. 
-        dc_cls = dataclasses.dataclass(init=True, frozen=__debug__,
+        dc_cls = dataclasses.dataclass(init=True, frozen=True,
                                        eq=False, repr=False)(cls)
 
         _augment_array_dataclass(dc_cls, generate_hash=hash)

From f0a4bbeef9aaa314f29066c69cbd36bbdd3a4463 Mon Sep 17 00:00:00 2001
From: Mike Campbell
Date: Thu, 23 Jan 2025 13:27:07 -0600
Subject: [PATCH 130/178] Repair merge mistake

---
 pytato/distributed/partition.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index 2d5734f88..37a5bfbf7 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -321,10 +321,10 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder:
 class _PartCommIDs:
     """A *part*, unlike a *batch*, begins with receives and ends with sends.
     """
-    recv_ids: immutabledict[CommunicationOpIdentifier, None]
-    send_ids: immutabledict[CommunicationOpIdentifier, None]
-    # recv_ids: FrozenOrderedSet[CommunicationOpIdentifier]
-    # send_ids: FrozenOrderedSet[CommunicationOpIdentifier]
+    # recv_ids: immutabledict[CommunicationOpIdentifier, None]
+    # send_ids: immutabledict[CommunicationOpIdentifier, None]
+    recv_ids: FrozenOrderedSet[CommunicationOpIdentifier]
+    send_ids: FrozenOrderedSet[CommunicationOpIdentifier]
 
 
 # {{{ _make_distributed_partition

From 949d438971e6cf451f7ebec69b2d10f3dc3742e2 Mon Sep 17 00:00:00 2001
From: Mike Campbell
Date: Thu, 23 Jan 2025 14:19:50 -0600
Subject: [PATCH 131/178] Do not directly set OrderedSet members.

---
 pytato/distributed/partition.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index 37a5bfbf7..bab4ae167 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -519,7 +519,8 @@ def _schedule_task_batches_counted(
 
     for task_id, dep_level in task_to_dep_level.items():
         if task_id not in task_batches[dep_level]:
-            task_batches[dep_level][task_id] = None
+            # task_batches[dep_level][task_id] = None
+            task_batches[dep_level].add(task_id)
 
     return task_batches, visits_in_depend + len(task_to_dep_level.keys())
 
@@ -592,15 +593,18 @@ def post_visit(self, expr: Any) -> None:
         from pytato.tags import ImplStored
         if (isinstance(expr, Array)
                 and expr.tags_of_type(ImplStored)):
-            self.materialized_arrays[expr] = None
+            # self.materialized_arrays[expr] = None
+            self.materialized_arrays.add(expr)
 
         if isinstance(expr, LoopyCallResult):
-            self.materialized_arrays[expr] = None
+            # self.materialized_arrays[expr] = None
+            self.materialized_arrays.add(expr)
             from pytato.loopy import LoopyCall
             assert isinstance(expr._container, LoopyCall)
             for _, subexpr in sorted(expr._container.bindings.items()):
                 if isinstance(subexpr, Array):
-                    self.materialized_arrays[subexpr] = None
+                    # self.materialized_arrays[subexpr] = None
+                    self.materialized_arrays.add(subexpr)
                 else:
                     assert isinstance(subexpr, SCALAR_CLASSES)
 
@@ -943,10 +947,12 @@ def get_materialized_predecessors(ary: Array) -> OrderedSet[Array]:
         for pred in direct_preds_getter(ary):
             assert isinstance(pred, Array)
             if pred in materialized_arrays:
-                materialized_preds[pred] = None
+                # materialized_preds[pred] = None
+                materialized_preds.add(pred)
             else:
                 for p in get_materialized_predecessors(pred):
-                    materialized_preds[p] = None
+                    # materialized_preds[p] = None
+                    materialized_preds.add(p)
         return materialized_preds
 
     stored_arrays_promoted_to_part_outputs = FrozenOrderedSet(

From 2014c174bebe0eb80b33ee72d16b55153a3df2cf Mon Sep 17 
00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 16:16:10 -0600 Subject: [PATCH 132/178] disable default implementation of get_cache_key and get_function_definition_cache_key for extra args case ambiguous due to the fact that any arg can be specified with/without keyword --- pytato/transform/__init__.py | 15 +++++++++++---- pytato/transform/einsum_distributive_law.py | 9 +++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index dc1045f25..c8f195d86 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -406,13 +406,20 @@ def __init__( def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + raise NotImplementedError( + "Derived classes must override get_cache_key if using extra inputs.") + return expr def get_function_definition_cache_key( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + raise NotImplementedError( + "Derived classes must override get_function_definition_cache_key if " + "using extra inputs.") + return expr def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: key = self._cache.get_key(expr, *args, **kwargs) diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 8cd635f61..694901b03 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -57,6 +57,8 @@ Stack, ) from pytato.transform import ( + ArrayOrNames, + CacheKeyT, MappedT, TransformMapperWithExtraArgs, _verify_is_array, @@ -160,6 +162,13 @@ def __init__(self, super().__init__() self.how_to_distribute = how_to_distribute + def get_cache_key( + self, + expr: ArrayOrNames, + ctx: _EinsumDistributiveLawMapperContext | None + ) -> CacheKeyT: + return (expr, ctx) + def _map_input_base(self, expr: InputArgumentBase, ctx: _EinsumDistributiveLawMapperContext | None, From 39866e400aaf4b6e12cb09370e58619ef417ce82 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 16:16:55 -0600 Subject: [PATCH 133/178] add CacheInputs to simplify cache key handling logic --- pytato/analysis/__init__.py | 6 +- pytato/codegen.py | 4 +- pytato/distributed/partition.py | 10 +- pytato/transform/__init__.py | 164 +++++++++++++++++--------------- pytato/transform/metadata.py | 10 +- 5 files changed, 101 insertions(+), 93 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e1487b710..880a15b7a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -622,9 +622,9 @@ def combine(self, *args: int) -> int: return sum(args) def rec(self, expr: ArrayOrNames) -> int: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -639,7 +639,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(expr, 0, key=key) + self._cache.add(inputs, 0) return result diff --git a/pytato/codegen.py b/pytato/codegen.py index 86a328929..cb957f076 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -138,8 +138,8 @@ def 
__init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 741e36548..a022f8f8e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -240,9 +240,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None, + TransformMapperCache[FunctionDefinition, []] | None = None, ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -261,7 +261,7 @@ def clone_for_callee( return type(self)( {}, {}, {}, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) @@ -294,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: return new_send def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index c8f195d86..af73906bb 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,6 +46,7 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -93,6 +94,7 @@ __doc__ = """ .. autoclass:: Mapper +.. autoclass:: CacheInputs .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapperCache @@ -304,12 +306,45 @@ def __call__( CacheKeyT: TypeAlias = Hashable -class CachedMapperCache(Generic[CacheExprT, CacheResultT]): +class CacheInputs(Generic[CacheExprT, P]): + """ + Data structure for inputs to :class:`CachedMapperCache`. + + .. attribute:: expr + + The input expression being mapped. + + .. attribute:: key + + The cache key corresponding to *expr* and any additional inputs that were + passed. + + """ + def __init__( + self, + expr: CacheExprT, + key_func: Callable[..., CacheKeyT], + *args: P.args, + **kwargs: P.kwargs): + self.expr: CacheExprT = expr + self._args: tuple[Any, ...] = args + self._kwargs: dict[str, Any] = kwargs + self._key_func = key_func + + @memoize_method + def _get_key(self) -> CacheKeyT: + return self._key_func(self.expr, *self._args, **self._kwargs) + + @property + def key(self) -> CacheKeyT: + return self._get_key() + + +class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): """ Cache for mappers. .. automethod:: __init__ - .. method:: get_key Compute the key for an input expression. @@ -317,37 +352,16 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT]): .. 
automethod:: retrieve .. automethod:: clear """ - def __init__( - self, - key_func: Callable[..., CacheKeyT]) -> None: - """ - Initialize the cache. - - :arg key_func: Function to compute a hashable cache key from an input - expression and any extra arguments. - """ - self.get_key = key_func - + def __init__(self) -> None: + """Initialize the cache.""" self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} def add( self, - key_inputs: - CacheExprT - # Currently, Python's type system doesn't have a way to annotate - # containers of args/kwargs (ParamSpec won't work here). So we have - # to fall back to using Any. More details here: - # https://github.com/python/typing/issues/1252 - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - result: CacheResultT, - key: CacheKeyT | None = None) -> CacheResultT: + inputs: CacheInputs[CacheExprT, P], + result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) + key = inputs.key assert key not in self._expr_key_to_result, \ f"Cache entry is already present for key '{key}'." @@ -356,20 +370,9 @@ def add( return result - def retrieve( - self, - key_inputs: - CacheExprT - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - key: CacheKeyT | None = None) -> CacheResultT: + def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) - + key = inputs.key return self._expr_key_to_result[key] def clear(self) -> None: @@ -389,20 +392,20 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): def __init__( self, _cache: - CachedMapperCache[ArrayOrNames, ResultT] | None = None, + CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: - CachedMapperCache[FunctionDefinition, FunctionResultT] | None = None + CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None ) -> None: super().__init__() - self._cache: CachedMapperCache[ArrayOrNames, ResultT] = ( + self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache(self.get_cache_key)) + else CachedMapperCache()) self._function_cache: CachedMapperCache[ - FunctionDefinition, FunctionResultT] = ( + FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache(self.get_function_definition_cache_key)) + else CachedMapperCache()) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -421,33 +424,39 @@ def get_function_definition_cache_key( "using extra inputs.") return expr + def _make_cache_inputs( + self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputs[ArrayOrNames, P]: + return CacheInputs(expr, self.get_cache_key, *args, **kwargs) + + def _make_function_definition_cache_inputs( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputs[FunctionDefinition, P]: + return CacheInputs( + expr, self.get_function_definition_cache_key, *args, **kwargs) + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: - key = self._cache.get_key(expr, *args, **kwargs) + inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - 
return self._cache.retrieve((expr, args, kwargs), key=key) + return self._cache.retrieve(inputs) except KeyError: - return self._cache.add( - (expr, args, kwargs), - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - Mapper.rec(self, expr, *args, **kwargs), - key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: - key = self._function_cache.get_key(expr, *args, **kwargs) + inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve((expr, args, kwargs), key=key) + return self._function_cache.retrieve(inputs) except KeyError: return self._function_cache.add( - (expr, args, kwargs), # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - Mapper.rec_function_definition(self, expr, *args, **kwargs), - key=key) + inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs)) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -463,7 +472,7 @@ def clone_for_callee( # {{{ TransformMapper -class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT]): +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): pass @@ -477,8 +486,8 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -499,9 +508,9 @@ class TransformMapperWithExtraArgs( """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None + TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -1529,8 +1538,8 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn @@ -1540,18 +1549,17 @@ def clone_for_callee( return type(self)( self.map_fn, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return 
self._cache.retrieve(inputs) except KeyError: - return self._cache.add( - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - expr, Mapper.rec(self, self.map_fn(expr)), key=key) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} @@ -2076,8 +2084,8 @@ class DataWrapperDeduplicator(CopyMapper): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index d50da22e0..200aa25b4 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -416,9 +416,9 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]], tag_corresponding_redn_descr: bool, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None): + TransformMapperCache[FunctionDefinition, []] | None = None): super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]] = axis_to_tags @@ -465,9 +465,9 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: return result def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -478,7 +478,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(expr, result, key=key) + return self._cache.add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( From 6f6ccbea12493a779cdf38dd1fa838b05c421fea Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 09:25:55 -0600 Subject: [PATCH 134/178] rename expr_key* to input_key* --- pytato/transform/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index af73906bb..b5dd9f094 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -354,7 +354,7 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): """ def __init__(self) -> None: """Initialize the cache.""" - self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} def add( self, @@ -363,21 +363,20 @@ def add( """Cache a mapping result.""" key = inputs.key - assert key not in self._expr_key_to_result, \ 
+ assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." - self._expr_key_to_result[key] = result - + self._input_key_to_result[key] = result return result def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key - return self._expr_key_to_result[key] + return self._input_key_to_result[key] def clear(self) -> None: """Reset the cache.""" - self._expr_key_to_result = {} + self._input_key_to_result = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): From f7d5c7e4e1a12ed430940356a48ec063cd66a309 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 18 Feb 2025 14:22:24 -0600 Subject: [PATCH 135/178] refactor to avoid performance drop --- pytato/transform/__init__.py | 45 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b5dd9f094..4d2bbd2bd 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,7 +46,6 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -94,7 +93,7 @@ __doc__ = """ .. autoclass:: Mapper -.. autoclass:: CacheInputs +.. autoclass:: CacheInputsWithKey .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapperCache @@ -306,7 +305,7 @@ def __call__( CacheKeyT: TypeAlias = Hashable -class CacheInputs(Generic[CacheExprT, P]): +class CacheInputsWithKey(Generic[CacheExprT, P]): """ Data structure for inputs to :class:`CachedMapperCache`. @@ -314,6 +313,14 @@ class CacheInputs(Generic[CacheExprT, P]): The input expression being mapped. + .. attribute:: args + + A :class:`tuple` of extra positional arguments. + + .. attribute:: kwargs + + A :class:`dict` of extra keyword arguments. + .. attribute:: key The cache key corresponding to *expr* and any additional inputs that were @@ -323,21 +330,13 @@ class CacheInputs(Generic[CacheExprT, P]): def __init__( self, expr: CacheExprT, - key_func: Callable[..., CacheKeyT], + key: CacheKeyT, *args: P.args, **kwargs: P.kwargs): self.expr: CacheExprT = expr - self._args: tuple[Any, ...] = args - self._kwargs: dict[str, Any] = kwargs - self._key_func = key_func - - @memoize_method - def _get_key(self) -> CacheKeyT: - return self._key_func(self.expr, *self._args, **self._kwargs) - - @property - def key(self) -> CacheKeyT: - return self._get_key() + self.args: tuple[Any, ...] 
= args + self.kwargs: dict[str, Any] = kwargs + self.key: CacheKeyT = key class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): @@ -358,7 +357,7 @@ def __init__(self) -> None: def add( self, - inputs: CacheInputs[CacheExprT, P], + inputs: CacheInputsWithKey[CacheExprT, P], result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" key = inputs.key @@ -369,7 +368,7 @@ def add( self._input_key_to_result[key] = result return result - def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: + def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key return self._input_key_to_result[key] @@ -425,14 +424,16 @@ def get_function_definition_cache_key( def _make_cache_inputs( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> CacheInputs[ArrayOrNames, P]: - return CacheInputs(expr, self.get_cache_key, *args, **kwargs) + ) -> CacheInputsWithKey[ArrayOrNames, P]: + return CacheInputsWithKey( + expr, self.get_cache_key(expr, *args, **kwargs), *args, **kwargs) def _make_function_definition_cache_inputs( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> CacheInputs[FunctionDefinition, P]: - return CacheInputs( - expr, self.get_function_definition_cache_key, *args, **kwargs) + ) -> CacheInputsWithKey[FunctionDefinition, P]: + return CacheInputsWithKey( + expr, self.get_function_definition_cache_key(expr, *args, **kwargs), + *args, **kwargs) def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: inputs = self._make_cache_inputs(expr, *args, **kwargs) From 81d300a89fb66b1909301806a4fa316cc51f01c5 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Mar 2025 16:55:49 -0500 Subject: [PATCH 136/178] add comment explaining why CachedMapper.get_cache_key and get_function_definition_cache_key are not defined for general extra args/kwargs --- pytato/transform/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 4d2bbd2bd..bc9a45da7 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -409,6 +409,9 @@ def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs ) -> CacheKeyT: if args or kwargs: + # Depending on whether extra arguments are passed by position or by + # keyword, they can end up in either args or kwargs; hence key is not + # uniquely defined in the general case raise NotImplementedError( "Derived classes must override get_cache_key if using extra inputs.") return expr @@ -417,6 +420,9 @@ def get_function_definition_cache_key( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> CacheKeyT: if args or kwargs: + # Depending on whether extra arguments are passed by position or by + # keyword, they can end up in either args or kwargs; hence key is not + # uniquely defined in the general case raise NotImplementedError( "Derived classes must override get_function_definition_cache_key if " "using extra inputs.") From 46e0ffa43de797bf920305cc6059c85f7a859da1 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 19 Sep 2024 19:31:13 -0500 Subject: [PATCH 137/178] add map_dict_of_named_arrays to DirectPredecessorsGetter --- pytato/analysis/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 880a15b7a..e461edf5e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -337,6 +337,10 @@ class 
DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr._data.values()) + def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]: return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) From ca6911555370dbd670052d78bc6d80963c51a105 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 24 Sep 2024 14:42:38 -0500 Subject: [PATCH 138/178] support functions as inputs and outputs in DirectPredecessorsGetter --- pytato/analysis/__init__.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e461edf5e..65d7c7756 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -323,7 +323,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): +class DirectPredecessorsGetter( + Mapper[ + FrozenOrderedSet[ArrayOrNames | FunctionDefinition], + FrozenOrderedSet[ArrayOrNames], + []]): """ Mapper to get the `direct predecessors @@ -334,6 +338,10 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): We only consider the predecessors of a nodes in a data-flow sense. """ + def __init__(self, *, include_functions: bool = False) -> None: + super().__init__() + self.include_functions = include_functions + def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) @@ -401,8 +409,17 @@ def map_distributed_send_ref_holder(self, ) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet([expr.passthrough_data]) - def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]: - return FrozenOrderedSet(expr.bindings.values()) + def map_call( + self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]: + result: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \ + FrozenOrderedSet(expr.bindings.values()) + if self.include_functions: + result = result | FrozenOrderedSet([expr.function]) + return result + + def map_function_definition( + self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr.returns.values()) def map_named_call_result( self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]: From 180040dd2af76fb5ef48b158fda6a44093bbf3e6 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 29 Aug 2024 16:57:13 -0500 Subject: [PATCH 139/178] add collision/duplication checks to CachedMapper/TransformMapper/TransformMapperWithExtraArgs --- pytato/analysis/__init__.py | 4 +- pytato/distributed/partition.py | 2 +- pytato/transform/__init__.py | 270 ++++++++++++++++++++++++++++++-- pytato/transform/metadata.py | 4 +- 4 files changed, 258 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 65d7c7756..019fbd846 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -645,7 +645,7 @@ def combine(self, *args: int) -> int: def rec(self, expr: ArrayOrNames) -> int: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return 
self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -660,7 +660,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(inputs, 0) + self._cache_add(inputs, 0) return result diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index a022f8f8e..73eec2745 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -296,7 +296,7 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index bc9a45da7..123d2b858 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -188,6 +188,14 @@ class ForeignObjectError(ValueError): pass +class CacheCollisionError(ValueError): + pass + + +class CacheNoOpDuplicationError(ValueError): + pass + + # {{{ mapper base class ResultT = TypeVar("ResultT") @@ -300,7 +308,7 @@ def __call__( # {{{ CachedMapper -CacheExprT = TypeVar("CacheExprT") +CacheExprT = TypeVar("CacheExprT", ArrayOrNames, FunctionDefinition) CacheResultT = TypeVar("CacheResultT") CacheKeyT: TypeAlias = Hashable @@ -351,9 +359,18 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): .. automethod:: retrieve .. automethod:: clear """ - def __init__(self) -> None: - """Initialize the cache.""" + def __init__(self, err_on_collision: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + """ + self.err_on_collision = err_on_collision + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} + if self.err_on_collision: + self._input_key_to_expr: dict[CacheKeyT, CacheExprT] = {} def add( self, @@ -366,16 +383,27 @@ def add( f"Cache entry is already present for key '{key}'." 
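        # Store the new result; with collision checking enabled, also remember
        # which expression produced it so that retrieve() can verify identity.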
self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + return result def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key - return self._input_key_to_result[key] + + result = self._input_key_to_result[key] + + if self.err_on_collision and inputs.expr is not self._input_key_to_expr[key]: + raise CacheCollisionError + + return result def clear(self) -> None: """Reset the cache.""" self._input_key_to_result = {} + if self.err_on_collision: + self._input_key_to_expr = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): @@ -389,6 +417,7 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """ def __init__( self, + err_on_collision: bool = False, _cache: CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: @@ -398,12 +427,12 @@ def __init__( self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache()) + else CachedMapperCache(err_on_collision=err_on_collision)) self._function_cache: CachedMapperCache[ FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache()) + else CachedMapperCache(err_on_collision=err_on_collision)) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -441,24 +470,53 @@ def _make_function_definition_cache_inputs( expr, self.get_function_definition_cache_key(expr, *args, **kwargs), *args, **kwargs) + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ResultT) -> ResultT: + return self._cache.add(inputs, result) + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionResultT) -> FunctionResultT: + return self._function_cache.add(inputs, result) + + def _cache_retrieve(self, inputs: CacheInputsWithKey[ArrayOrNames, P]) -> ResultT: + try: + return self._cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_retrieve( + self, inputs: CacheInputsWithKey[FunctionDefinition, P]) -> FunctionResultT: + try: + return self._function_cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs)) + return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve(inputs) + return self._function_cache_retrieve(inputs) except KeyError: - return self._function_cache.add( + return self._function_cache_add( # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see 
https://github.com/inducer/pytato/pull/585 @@ -470,8 +528,10 @@ def clone_for_callee( Called to clone *self* before starting traversal of a :class:`pytato.function.FunctionDefinition`. """ - # Functions are cached globally, but arrays aren't - return type(self)(_function_cache=self._function_cache) + return type(self)( + err_on_collision=self._cache.err_on_collision, + # Functions are cached globally, but arrays aren't + _function_cache=self._function_cache) # }}} @@ -479,7 +539,67 @@ def clone_for_callee( # {{{ TransformMapper class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): - pass + """ + Cache for :class:`TransformMapper` and :class:`TransformMapperWithExtraArgs`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_no_op_duplication: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_no_op_duplication = err_on_no_op_duplication + + def add( + self, + inputs: CacheInputsWithKey[CacheExprT, P], + result: CacheExprT) -> CacheExprT: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + if self.err_on_no_op_duplication: + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + if ( + hash(result) == hash(inputs.expr) + and result == inputs.expr + and result is not inputs.expr + # Need this check in order to handle input DAGs that have existing + # duplicates. Deduplication will potentially replace predecessors + # of `expr` with cached versions, producing a new `result` that has + # the same cache key as `expr`. + and all( + result_pred is pred + for pred, result_pred in zip( + pred_getter(inputs.expr), + pred_getter(result), + strict=True))): + raise CacheNoOpDuplicationError from None + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): @@ -489,13 +609,71 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): Enables certain operations that can only be done if the mapping results are also arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool = False, + err_on_no_op_duplication: bool = False, _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. 
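            For example, a ``map_placeholder`` override that rebuilds a new
            ``Placeholder`` with identical fields instead of returning *expr*
            itself would trip this check.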
+ """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, []], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + _function_cache=function_cache) # }}} @@ -511,14 +689,72 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool = False, + err_on_no_op_duplication: bool = False, _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. 
+ """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, P]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + _function_cache=function_cache) # }}} @@ -1560,12 +1796,12 @@ def clone_for_callee( def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 - return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr))) + return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 200aa25b4..de3625978 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -467,7 +467,7 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, @@ -478,7 +478,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(inputs, result) + return self._cache_add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( From 3982b18686b4ccc18adc42698ee1e35421dcb901 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 18 Feb 2025 14:41:56 -0600 Subject: [PATCH 140/178] fix doc --- pytato/transform/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 123d2b858..9c093dc66 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -568,8 +568,7 @@ def add( """ Cache a mapping 
result. - Returns the cached result (which may not be identical to *result* if a - result was already cached with the same result key). + Returns *result*. """ key = inputs.key From d2ef3a869044b029c0a69f98fbb962a6f4f0669d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 18 Feb 2025 14:49:23 -0600 Subject: [PATCH 141/178] change terminology from 'no-op duplication' to 'mapper-created duplicate' --- pytato/transform/__init__.py | 48 ++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 9c093dc66..25e314089 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -192,7 +192,7 @@ class CacheCollisionError(ValueError): pass -class CacheNoOpDuplicationError(ValueError): +class MapperCreatedDuplicateError(ValueError): pass @@ -548,18 +548,18 @@ class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): def __init__( self, err_on_collision: bool, - err_on_no_op_duplication: bool) -> None: + err_on_created_duplicate: bool) -> None: """ Initialize the cache. :arg err_on_collision: Raise an exception if two distinct input expression instances have the same key. - :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + :arg err_on_created_duplicate: Raise an exception if mapping produces a new array instance that has the same key as the input array. """ super().__init__(err_on_collision=err_on_collision) - self.err_on_no_op_duplication = err_on_no_op_duplication + self.err_on_created_duplicate = err_on_created_duplicate def add( self, @@ -575,7 +575,7 @@ def add( assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." - if self.err_on_no_op_duplication: + if self.err_on_created_duplicate: from pytato.analysis import DirectPredecessorsGetter pred_getter = DirectPredecessorsGetter(include_functions=True) if ( @@ -592,7 +592,7 @@ def add( pred_getter(inputs.expr), pred_getter(result), strict=True))): - raise CacheNoOpDuplicationError from None + raise MapperCreatedDuplicateError from None self._input_key_to_result[key] = result if self.err_on_collision: @@ -615,25 +615,25 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): def __init__( self, err_on_collision: bool = False, - err_on_no_op_duplication: bool = False, + err_on_created_duplicate: bool = False, _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: """ :arg err_on_collision: Raise an exception if two distinct input array instances have the same key. - :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + :arg err_on_created_duplicate: Raise an exception if mapping produces a new array instance that has the same key as the input array. 
""" if _cache is None: _cache = TransformMapperCache( err_on_collision=err_on_collision, - err_on_no_op_duplication=err_on_no_op_duplication) + err_on_created_duplicate=err_on_created_duplicate) if _function_cache is None: _function_cache = TransformMapperCache( err_on_collision=err_on_collision, - err_on_no_op_duplication=err_on_no_op_duplication) + err_on_created_duplicate=err_on_created_duplicate) super().__init__( err_on_collision=err_on_collision, @@ -646,9 +646,9 @@ def _cache_add( result: ArrayOrNames) -> ArrayOrNames: try: return self._cache.add(inputs, result) - except CacheNoOpDuplicationError as e: + except MapperCreatedDuplicateError as e: raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " + f"mapper-created duplicate detected on {type(inputs.expr)} in " f"{type(self)}.") from e def _function_cache_add( @@ -657,9 +657,9 @@ def _function_cache_add( result: FunctionDefinition) -> FunctionDefinition: try: return self._function_cache.add(inputs, result) - except CacheNoOpDuplicationError as e: + except MapperCreatedDuplicateError as e: raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " + f"mapper-created duplicate detected on {type(inputs.expr)} in " f"{type(self)}.") from e def clone_for_callee(self, function: FunctionDefinition) -> Self: @@ -671,7 +671,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: "TransformMapperCache[FunctionDefinition, []]", self._function_cache) return type(self)( err_on_collision=function_cache.err_on_collision, - err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + err_on_created_duplicate=function_cache.err_on_created_duplicate, _function_cache=function_cache) # }}} @@ -695,7 +695,7 @@ class TransformMapperWithExtraArgs( def __init__( self, err_on_collision: bool = False, - err_on_no_op_duplication: bool = False, + err_on_created_duplicate: bool = False, _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None @@ -703,18 +703,18 @@ def __init__( """ :arg err_on_collision: Raise an exception if two distinct input array instances have the same key. - :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + :arg err_on_created_duplicate: Raise an exception if mapping produces a new array instance that has the same key as the input array. 
""" if _cache is None: _cache = TransformMapperCache( err_on_collision=err_on_collision, - err_on_no_op_duplication=err_on_no_op_duplication) + err_on_created_duplicate=err_on_created_duplicate) if _function_cache is None: _function_cache = TransformMapperCache( err_on_collision=err_on_collision, - err_on_no_op_duplication=err_on_no_op_duplication) + err_on_created_duplicate=err_on_created_duplicate) super().__init__( err_on_collision=err_on_collision, @@ -727,9 +727,9 @@ def _cache_add( result: ArrayOrNames) -> ArrayOrNames: try: return self._cache.add(inputs, result) - except CacheNoOpDuplicationError as e: + except MapperCreatedDuplicateError as e: raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " + f"mapper-created duplicate detected on {type(inputs.expr)} in " f"{type(self)}.") from e def _function_cache_add( @@ -738,9 +738,9 @@ def _function_cache_add( result: FunctionDefinition) -> FunctionDefinition: try: return self._function_cache.add(inputs, result) - except CacheNoOpDuplicationError as e: + except MapperCreatedDuplicateError as e: raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " + f"mapper-created duplicate detected on {type(inputs.expr)} in " f"{type(self)}.") from e def clone_for_callee(self, function: FunctionDefinition) -> Self: @@ -752,7 +752,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: "TransformMapperCache[FunctionDefinition, P]", self._function_cache) return type(self)( err_on_collision=function_cache.err_on_collision, - err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + err_on_created_duplicate=function_cache.err_on_created_duplicate, _function_cache=function_cache) # }}} From 20d7a55b6211c1b09cdfadf045acbf493e6d2895 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 19 Feb 2025 13:04:20 -0600 Subject: [PATCH 142/178] reword explanation of predecessor check in duplication check --- pytato/transform/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 25e314089..512d8711b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -582,10 +582,14 @@ def add( hash(result) == hash(inputs.expr) and result == inputs.expr and result is not inputs.expr - # Need this check in order to handle input DAGs that have existing - # duplicates. Deduplication will potentially replace predecessors - # of `expr` with cached versions, producing a new `result` that has - # the same cache key as `expr`. + # Only consider "direct" duplication, not duplication resulting + # from equality-preserving changes to predecessors. Assume that + # such changes are OK, otherwise they would have been detected + # at the point at which they originated. (For example, consider + # a DAG containing pre-existing duplicates. If a subexpression + # of *expr* is a duplicate and is replaced with a previously + # encountered version from the cache, a new instance of *expr* + # must be created. This should not trigger an error.) 
and all( result_pred is pred for pred, result_pred in zip( From 99b6e56ed468becc4256eced0ebf7d60505f4d91 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 26 Feb 2025 15:25:19 -0600 Subject: [PATCH 143/178] change CacheExprT constraint to use bound= apparently TypeVar(..., ) doesn't include subclasses of --- pytato/transform/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 512d8711b..085b7b45f 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -308,7 +308,7 @@ def __call__( # {{{ CachedMapper -CacheExprT = TypeVar("CacheExprT", ArrayOrNames, FunctionDefinition) +CacheExprT = TypeVar("CacheExprT", bound=ArrayOrNames | FunctionDefinition) CacheResultT = TypeVar("CacheResultT") CacheKeyT: TypeAlias = Hashable @@ -593,8 +593,10 @@ def add( and all( result_pred is pred for pred, result_pred in zip( - pred_getter(inputs.expr), - pred_getter(result), + # type-ignore-reason: mypy doesn't seem to recognize + # overloaded Mapper.__call__ here + pred_getter(inputs.expr), # type: ignore[arg-type] + pred_getter(result), # type: ignore[arg-type] strict=True))): raise MapperCreatedDuplicateError from None From 111fcd4c2080af03d531070277e55a1eabdbaab4 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Mar 2025 17:20:48 -0500 Subject: [PATCH 144/178] add a couple of missing clone_for_callee definitions --- pytato/transform/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 085b7b45f..6598b9bcb 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1373,6 +1373,9 @@ def map_call(self, expr: Call) -> R: def map_named_call_result(self, expr: NamedCallResult) -> R: return self.rec(expr._container) + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + # }}} @@ -2386,6 +2389,11 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: self.data_wrapper_cache[cache_key] = expr return expr + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: """For the expression graph given as *array_or_names*, replace all From 2e820d338a44c5fb373c42674c83ffe7d5170de1 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Sep 2024 20:45:47 -0500 Subject: [PATCH 145/178] add result deduplication to transform mappers --- pytato/transform/__init__.py | 59 +++++++++++++++++++++--------------- test/test_codegen.py | 3 +- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 6598b9bcb..dfd710284 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -561,6 +561,8 @@ def __init__( self.err_on_created_duplicate = err_on_created_duplicate + self._result_to_cached_result: dict[CacheExprT, CacheExprT] = {} + def add( self, inputs: CacheInputsWithKey[CacheExprT, P], @@ -568,37 +570,44 @@ def add( """ Cache a mapping result. - Returns *result*. + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). """ key = inputs.key assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." 
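        # Reuse an equal result created earlier, if any, so that equal
        # subexpressions map to a single instance across the DAG.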
- if self.err_on_created_duplicate: - from pytato.analysis import DirectPredecessorsGetter - pred_getter = DirectPredecessorsGetter(include_functions=True) - if ( - hash(result) == hash(inputs.expr) - and result == inputs.expr - and result is not inputs.expr - # Only consider "direct" duplication, not duplication resulting - # from equality-preserving changes to predecessors. Assume that - # such changes are OK, otherwise they would have been detected - # at the point at which they originated. (For example, consider - # a DAG containing pre-existing duplicates. If a subexpression - # of *expr* is a duplicate and is replaced with a previously - # encountered version from the cache, a new instance of *expr* - # must be created. This should not trigger an error.) - and all( - result_pred is pred - for pred, result_pred in zip( - # type-ignore-reason: mypy doesn't seem to recognize - # overloaded Mapper.__call__ here - pred_getter(inputs.expr), # type: ignore[arg-type] - pred_getter(result), # type: ignore[arg-type] - strict=True))): - raise MapperCreatedDuplicateError from None + try: + result = self._result_to_cached_result[result] + except KeyError: + if self.err_on_created_duplicate: + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + if ( + hash(result) == hash(inputs.expr) + and result == inputs.expr + and result is not inputs.expr + # Only consider "direct" duplication, not duplication + # resulting from equality-preserving changes to predecessors. + # Assume that such changes are OK, otherwise they would have + # been detected at the point at which they originated. (For + # example, consider a DAG containing pre-existing duplicates. + # If a subexpression of *expr* is a duplicate and is replaced + # with a previously encountered version from the cache, a + # new instance of *expr* must be created. This should not + # trigger an error.) + and all( + result_pred is pred + for pred, result_pred in zip( + # type-ignore-reason: mypy doesn't seem to recognize + # overloaded Mapper.__call__ here + pred_getter(inputs.expr), # type: ignore[arg-type] + pred_getter(result), # type: ignore[arg-type] + strict=True))): + raise MapperCreatedDuplicateError from None + + self._result_to_cached_result[result] = result self._input_key_to_result[key] = result if self.err_on_collision: diff --git a/test/test_codegen.py b/test/test_codegen.py index 0c6972cf6..2193b7fc9 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1621,7 +1621,8 @@ def test_zero_size_cl_array_dedup(ctx_factory): dedup_dw_out, count_duplicates=True) # 'x2' would be merged with 'x1' as both of them point to the same data # 'x3' would be merged with 'x4' as both of them point to the same data - assert num_nodes_new == (num_nodes_old - 2) + # '2*x2' would be merged with '2*x1' as they are identical expressions + assert num_nodes_new == (num_nodes_old - 3) # {{{ test_deterministic_codegen From 42429a548cad731ba69224cb00e9072049d2f0bf Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 4 Sep 2024 21:35:04 -0500 Subject: [PATCH 146/178] add FIXME --- pytato/transform/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index dfd710284..22eeebef0 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1796,6 +1796,7 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, + # FIXME: Should map_fn be applied to functions too? 
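        # (As implemented, map_fn is applied to each array expression before
        # it is recursed into; see rec() below.)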
map_fn: Callable[[ArrayOrNames], ArrayOrNames], _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None From ece90d02ee8e4544dcbff0afee49988ebc236b2e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Jun 2024 13:32:42 -0500 Subject: [PATCH 147/178] avoid unnecessary duplication in CopyMapper/CopyMapperWithExtraArgs --- pytato/transform/__init__.py | 622 +++++++++++++++++++++++------------ 1 file changed, 403 insertions(+), 219 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 22eeebef0..f60d2e3a8 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -788,63 +788,104 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] ) -> tuple[IndexOrShapeExpr, ...]: # type-ignore-reason: apparently mypy cannot substitute typevars # here. - return tuple(self.rec(s) if isinstance(s, Array) else s # type: ignore[misc] - for s in situp) + new_situp = tuple( + self.rec(s) if isinstance(s, Array) else s + for s in situp) + if all(new_s is s for s, new_s in zip(situp, new_situp, strict=True)): + return situp + else: + return new_situp # type: ignore[return-value] def map_index_lambda(self, expr: IndexLambda) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + new_shape = self.rec_idx_or_size_tuple(expr.shape) + new_bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr) for name, subexpr in sorted(expr.bindings.items())}) - return IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, - bindings=bindings, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder) -> Array: assert expr.name is not None - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return Placeholder(name=expr.name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> Array: - arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate) -> Array: - arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) - return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags, - 
non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Concatenate(arrays=new_arrays, axis=expr.axis, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll) -> Array: - return Roll(array=_verify_is_array(self.rec(expr.array)), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + if new_ary is expr.array: + return expr + else: + return Roll(array=new_ary, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: - return AxisPermutation(array=_verify_is_array(self.rec(expr.array)), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + if new_ary is expr.array: + return expr + else: + return AxisPermutation(array=new_ary, + axis_permutation=expr.axis_permutation, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: - return type(expr)(_verify_is_array(self.rec(expr.array)), - indices=self.rec_idx_or_size_tuple(expr.indices), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + new_indices = self.rec_idx_or_size_tuple(expr.indices) + if new_ary is expr.array and new_indices is expr.indices: + return expr + else: + return type(expr)(new_ary, + indices=new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex) -> Array: return self._map_index_base(expr) @@ -860,91 +901,131 @@ def map_non_contiguous_advanced_index(self, return self._map_index_base(expr) def map_data_wrapper(self, expr: DataWrapper) -> Array: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr.shape), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return DataWrapper( + data=expr.data, + shape=new_shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam) -> Array: assert expr.name is not None - return SizeParam( - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + return expr def map_einsum(self, expr: Einsum) -> Array: - return Einsum(expr.access_descriptors, - tuple(_verify_is_array(self.rec(arg)) for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_named_array(self, expr: NamedArray) -> Array: - container = self.rec(expr._container) - assert isinstance(container, AbstractResultWithNamedArrays) - return type(expr)(container, - expr.name, + new_args = tuple(_verify_is_array(self.rec(arg)) for arg in expr.args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + return expr + else: + return Einsum(expr.access_descriptors, + new_args, axes=expr.axes, + 
redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_named_array(self, expr: NamedArray) -> Array: + new_container = self.rec(expr._container) + assert isinstance(new_container, AbstractResultWithNamedArrays) + if new_container is expr._container: + return expr + else: + return type(expr)(new_container, + expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> DictOfNamedArrays: - return DictOfNamedArrays({key: _verify_is_array(self.rec(val.expr)) - for key, val in expr.items()}, - tags=expr.tags - ) + new_data = { + key: _verify_is_array(self.rec(val.expr)) + for key, val in expr.items()} + if all( + new_data_val is val.expr + for val, new_data_val in zip( + expr.values(), + new_data.values(), + strict=True)): + return expr + else: + return DictOfNamedArrays(new_data, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + new_bindings: Mapping[Any, Any] = immutabledict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - - return LoopyCall(translation_unit=expr.translation_unit, - bindings=bindings, - entrypoint=expr.entrypoint, - tags=expr.tags, - ) + if ( + frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return LoopyCall(translation_unit=expr.translation_unit, + bindings=new_bindings, + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: - rec_container = self.rec(expr._container) - assert isinstance(rec_container, LoopyCall) - return LoopyCallResult( - _container=rec_container, - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_container = self.rec(expr._container) + assert isinstance(new_container, LoopyCall) + if new_container is expr._container: + return expr + else: + return LoopyCallResult( + _container=new_container, + name=expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: - return Reshape(_verify_is_array(self.rec(expr.array)), - newshape=self.rec_idx_or_size_tuple(expr.newshape), - order=expr.order, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + new_newshape = self.rec_idx_or_size_tuple(expr.newshape) + if new_ary is expr.array and new_newshape is expr.newshape: + return expr + else: + return Reshape(new_ary, + newshape=new_newshape, + order=expr.order, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> Array: - return DistributedSendRefHolder( - send=DistributedSend( - data=_verify_is_array(self.rec(expr.send.data)), - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag), - passthrough_data=_verify_is_array(self.rec(expr.passthrough_data)), - ) + new_send_data = _verify_is_array(self.rec(expr.send.data)) + if new_send_data is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + data=new_send_data, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag) + new_passthrough = 
_verify_is_array(self.rec(expr.passthrough_data)) + if new_send is expr.send and new_passthrough is expr.passthrough_data: + return expr + else: + return DistributedSendRefHolder(new_send, new_passthrough) def map_distributed_recv(self, expr: DistributedRecv) -> Array: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=new_shape, dtype=expr.dtype, tags=expr.tags, + axes=expr.axes, non_equality_tags=expr.non_equality_tags) def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinition: @@ -953,19 +1034,37 @@ def map_function_definition(self, new_mapper = self.clone_for_callee(expr) new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} - return dataclasses.replace(expr, returns=immutabledict(new_returns)) + if all( + new_ret is ret + for ret, new_ret in zip( + expr.returns.values(), + new_returns.values(), + strict=True)): + return expr + else: + return dataclasses.replace(expr, returns=immutabledict(new_returns)) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return Call(self.rec_function_definition(expr.function), - immutabledict({name: self.rec(bnd) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, - ) + new_function = self.rec_function_definition(expr.function) + new_bindings = { + name: _verify_is_array(self.rec(bnd)) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) def map_named_call_result(self, expr: NamedCallResult) -> Array: - call = self.rec(expr._container) - assert isinstance(call, Call) - return call[expr.name] + new_call = self.rec(expr._container) + assert isinstance(new_call, Call) + return new_call[expr.name] class CopyMapperWithExtraArgs(TransformMapperWithExtraArgs[P]): @@ -989,70 +1088,102 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...], def map_index_lambda(self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + new_bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr, *args, **kwargs) for name, subexpr in sorted(expr.bindings.items())}) - return IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr.shape, - *args, **kwargs), - dtype=expr.dtype, - bindings=bindings, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder, *args: P.args, **kwargs: 
P.kwargs) -> Array: assert expr.name is not None - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr.shape, - *args, **kwargs), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return Placeholder(name=expr.name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple( + new_arrays: tuple[Array, ...] = tuple( _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple( + new_arrays: tuple[Array, ...] = tuple( _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) - return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Concatenate(arrays=new_arrays, axis=expr.axis, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array: - return Roll(array=_verify_is_array(self.rec(expr.array, *args, **kwargs)), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + if new_ary is expr.array: + return expr + else: + return Roll(array=new_ary, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> Array: - return AxisPermutation(array=_verify_is_array( - self.rec(expr.array, *args, **kwargs)), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + if new_ary is expr.array: + return expr + else: + return AxisPermutation(array=new_ary, + axis_permutation=expr.axis_permutation, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase, *args: P.args, **kwargs: P.kwargs) -> Array: assert isinstance(expr, _SuppliedAxesAndTagsMixin) - return type(expr)(_verify_is_array(self.rec(expr.array, *args, **kwargs)), - indices=self.rec_idx_or_size_tuple(expr.indices, - *args, **kwargs), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + new_indices = self.rec_idx_or_size_tuple(expr.indices, *args, **kwargs) + if new_ary is expr.array and new_indices is expr.indices: + return expr + else: + return type(expr)(new_ary, + indices=new_indices, + axes=expr.axes, + tags=expr.tags, + 
non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> Array: @@ -1073,98 +1204,141 @@ def map_non_contiguous_advanced_index(self, def map_data_wrapper(self, expr: DataWrapper, *args: P.args, **kwargs: P.kwargs) -> Array: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return DataWrapper( + data=expr.data, + shape=new_shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam, *args: P.args, **kwargs: P.kwargs) -> Array: assert expr.name is not None - return SizeParam(name=expr.name, axes=expr.axes, tags=expr.tags) + return expr def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> Array: - return Einsum(expr.access_descriptors, - tuple(_verify_is_array( - self.rec(arg, *args, **kwargs)) for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_named_array(self, - expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: - container = self.rec(expr._container, *args, **kwargs) - assert isinstance(container, AbstractResultWithNamedArrays) - return type(expr)(container, - expr.name, + new_args: tuple[Array, ...] = tuple( + _verify_is_array(self.rec(arg, *args, **kwargs)) for arg in expr.args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + return expr + else: + return Einsum(expr.access_descriptors, + new_args, axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_named_array(self, + expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: + new_container = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_container, AbstractResultWithNamedArrays) + if new_container is expr._container: + return expr + else: + return type(expr)(new_container, + expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: P.args, **kwargs: P.kwargs ) -> DictOfNamedArrays: - return DictOfNamedArrays({key: _verify_is_array( - self.rec(val.expr, *args, **kwargs)) - for key, val in expr.items()}, - tags=expr.tags, - ) + new_data = { + key: _verify_is_array(self.rec(val.expr, *args, **kwargs)) + for key, val in expr.items()} + if all( + new_data_val is val.expr + for val, new_data_val in zip( + expr.values(), + new_data.values(), + strict=True)): + return expr + else: + return DictOfNamedArrays(new_data, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall, *args: P.args, **kwargs: P.kwargs) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + new_bindings: Mapping[Any, Any] = immutabledict( {name: (self.rec(subexpr, *args, **kwargs) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - - return LoopyCall(translation_unit=expr.translation_unit, - bindings=bindings, - entrypoint=expr.entrypoint, - tags=expr.tags, - ) + if ( + frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return 
expr + else: + return LoopyCall(translation_unit=expr.translation_unit, + bindings=new_bindings, + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult, *args: P.args, **kwargs: P.kwargs) -> Array: - rec_loopy_call = self.rec(expr._container, *args, **kwargs) - assert isinstance(rec_loopy_call, LoopyCall) - return LoopyCallResult( - _container=rec_loopy_call, - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_container = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_container, LoopyCall) + if new_container is expr._container: + return expr + else: + return LoopyCallResult( + _container=new_container, + name=expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> Array: - return Reshape(_verify_is_array(self.rec(expr.array, *args, **kwargs)), - newshape=self.rec_idx_or_size_tuple(expr.newshape, - *args, **kwargs), - order=expr.order, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + new_newshape = self.rec_idx_or_size_tuple(expr.newshape, *args, **kwargs) + if new_ary is expr.array and new_newshape is expr.newshape: + return expr + else: + return Reshape(new_ary, + newshape=new_newshape, + order=expr.order, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, *args: P.args, **kwargs: P.kwargs) -> Array: - return DistributedSendRefHolder( - send=DistributedSend( - data=_verify_is_array(self.rec(expr.send.data, *args, **kwargs)), - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag), - passthrough_data=_verify_is_array( - self.rec(expr.passthrough_data, *args, **kwargs))) + new_send_data = _verify_is_array(self.rec(expr.send.data, *args, **kwargs)) + if new_send_data is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + data=new_send_data, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag) + new_passthrough = _verify_is_array( + self.rec(expr.passthrough_data, *args, **kwargs)) + if new_send is expr.send and new_passthrough is expr.passthrough_data: + return expr + else: + return DistributedSendRefHolder(new_send, new_passthrough) def map_distributed_recv(self, expr: DistributedRecv, *args: P.args, **kwargs: P.kwargs) -> Array: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=new_shape, dtype=expr.dtype, tags=expr.tags, + axes=expr.axes, non_equality_tags=expr.non_equality_tags) def map_function_definition( self, expr: FunctionDefinition, @@ -1176,17 +1350,27 @@ def map_function_definition( def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> AbstractResultWithNamedArrays: - return Call(self.rec_function_definition(expr.function, *args, **kwargs), - immutabledict({name: self.rec(bnd, *args, **kwargs) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, - ) + new_function = 
self.rec_function_definition(expr.function, *args, **kwargs) + new_bindings = { + name: self.rec(bnd, *args, **kwargs) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) def map_named_call_result(self, expr: NamedCallResult, *args: P.args, **kwargs: P.kwargs) -> Array: - call = self.rec(expr._container, *args, **kwargs) - assert isinstance(call, Call) - return call[expr.name] + new_call = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_call, Call) + return new_call[expr.name] # }}} From a882f484bc8b919404568689f3b4794537a180cb Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 20 Sep 2024 12:03:14 -0500 Subject: [PATCH 148/178] add Deduplicator --- pytato/transform/__init__.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index f60d2e3a8..c87435672 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -101,6 +101,7 @@ .. autoclass:: TransformMapperWithExtraArgs .. autoclass:: CopyMapper .. autoclass:: CopyMapperWithExtraArgs +.. autoclass:: Deduplicator .. autoclass:: CombineMapper .. autoclass:: DependencyMapper .. autoclass:: InputGatherer @@ -1375,6 +1376,28 @@ def map_named_call_result(self, expr: NamedCallResult, # }}} +# {{{ Deduplicator + +class Deduplicator(CopyMapper): + """Removes duplicate nodes from an expression.""" + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__( + err_on_collision=False, err_on_created_duplicate=False, + _cache=_cache, + _function_cache=_function_cache) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + +# }}} + + # {{{ CombineMapper class CombineMapper(CachedMapper[ResultT, FunctionResultT, []]): From c42fae337075619f8c68260f4ef0324759807afe Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Jun 2024 10:58:24 -0500 Subject: [PATCH 149/178] avoid unnecessary duplication in InlineMarker --- pytato/transform/calls.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 34f89cbc1..298b5351b 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -104,7 +104,11 @@ class InlineMarker(CopyMapper): Primary mapper for :func:`tag_all_calls_to_be_inlined`. 
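    Tags each :class:`~pytato.function.Call` encountered with
    :class:`~pytato.tags.InlineCallTag`; calls that already carry the tag are
    returned unchanged so that no duplicate nodes are created.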
""" def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return super().map_call(expr).tagged(InlineCallTag()) + rec_expr = super().map_call(expr) + if rec_expr.tags_of_type(InlineCallTag): + return rec_expr + else: + return rec_expr.tagged(InlineCallTag()) def inline_calls(expr: ArrayOrNames) -> ArrayOrNames: From 9d899b09e410e9a6af3ce426c8e61d972a382ece Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 27 Aug 2024 15:02:08 -0500 Subject: [PATCH 150/178] avoid duplication in tagged() for Axis/ReductionDescriptor/_SuppliedAxesAndTagsMixin --- pytato/array.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 8f92e5118..73d970a2e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -467,8 +467,11 @@ class Axis(Taggable): tags: frozenset[Tag] def _with_new_tags(self, tags: frozenset[Tag]) -> Axis: - from dataclasses import replace - return replace(self, tags=tags) + if tags != self.tags: + from dataclasses import replace + return replace(self, tags=tags) + else: + return self @dataclasses.dataclass(frozen=True) @@ -480,8 +483,11 @@ class ReductionDescriptor(Taggable): tags: frozenset[Tag] def _with_new_tags(self, tags: frozenset[Tag]) -> ReductionDescriptor: - from dataclasses import replace - return replace(self, tags=tags) + if tags != self.tags: + from dataclasses import replace + return replace(self, tags=tags) + else: + return self @array_dataclass() @@ -865,7 +871,10 @@ class _SuppliedAxesAndTagsMixin(Taggable): default=frozenset()) def _with_new_tags(self: Self, tags: frozenset[Tag]) -> Self: - return dataclasses.replace(self, tags=tags) + if tags != self.tags: + return dataclasses.replace(self, tags=tags) + else: + return self @dataclasses.dataclass(frozen=True, eq=False, repr=False) From 858a1cd5352d5550e7662b83c07ae6d85261374e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Jun 2024 12:48:20 -0500 Subject: [PATCH 151/178] avoid duplication in Array.with_tagged_axis --- pytato/array.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 73d970a2e..c246a92a6 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -839,10 +839,14 @@ def with_tagged_axis(self, iaxis: int, """ Returns a copy of *self* with *iaxis*-th axis tagged with *tags*. 
""" - new_axes = (self.axes[:iaxis] - + (self.axes[iaxis].tagged(tags),) - + self.axes[iaxis+1:]) - return self.copy(axes=new_axes) + new_axis = self.axes[iaxis].tagged(tags) + if new_axis is not self.axes[iaxis]: + new_axes = (self.axes[:iaxis] + + (self.axes[iaxis].tagged(tags),) + + self.axes[iaxis+1:]) + return self.copy(axes=new_axes) + else: + return self @memoize_method def __repr__(self) -> str: From a17cc40c925e671ba9ae3a80b52cb896293ca4fa Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Jun 2024 12:48:54 -0500 Subject: [PATCH 152/178] avoid duplication in with_tagged_reduction for IndexLambda/Einsum --- pytato/array.py | 58 ++++++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index c246a92a6..53d5215b8 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1127,20 +1127,22 @@ def with_tagged_reduction(self, f" '{self.var_to_reduction_descr.keys()}'," f" got '{reduction_variable}'.") - assert isinstance(self.var_to_reduction_descr, immutabledict) - new_var_to_redn_descr = dict(self.var_to_reduction_descr) - new_var_to_redn_descr[reduction_variable] = \ - self.var_to_reduction_descr[reduction_variable].tagged(tags) - - return type(self)(expr=self.expr, - shape=self.shape, - dtype=self.dtype, - bindings=self.bindings, - axes=self.axes, - var_to_reduction_descr=immutabledict - (new_var_to_redn_descr), - tags=self.tags, - non_equality_tags=self.non_equality_tags) + new_redn_descr = self.var_to_reduction_descr[reduction_variable].tagged(tags) + if new_redn_descr is not self.var_to_reduction_descr[reduction_variable]: + assert isinstance(self.var_to_reduction_descr, immutabledict) + new_var_to_redn_descr = dict(self.var_to_reduction_descr) + new_var_to_redn_descr[reduction_variable] = new_redn_descr + return type(self)(expr=self.expr, + shape=self.shape, + dtype=self.dtype, + bindings=self.bindings, + axes=self.axes, + var_to_reduction_descr=immutabledict + (new_var_to_redn_descr), + tags=self.tags, + non_equality_tags=self.non_equality_tags) + else: + return self # }}} @@ -1291,19 +1293,21 @@ def with_tagged_reduction(self, # }}} - assert isinstance(self.redn_axis_to_redn_descr, immutabledict) - new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr) - new_redn_axis_to_redn_descr[redn_axis] = \ - self.redn_axis_to_redn_descr[redn_axis].tagged(tags) - - return type(self)(access_descriptors=self.access_descriptors, - args=self.args, - axes=self.axes, - redn_axis_to_redn_descr=immutabledict - (new_redn_axis_to_redn_descr), - tags=self.tags, - non_equality_tags=self.non_equality_tags, - ) + new_redn_descr = self.redn_axis_to_redn_descr[redn_axis].tagged(tags) + if new_redn_descr is not self.redn_axis_to_redn_descr[redn_axis]: + assert isinstance(self.redn_axis_to_redn_descr, immutabledict) + new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr) + new_redn_axis_to_redn_descr[redn_axis] = new_redn_descr + return type(self)(access_descriptors=self.access_descriptors, + args=self.args, + axes=self.axes, + redn_axis_to_redn_descr=immutabledict + (new_redn_axis_to_redn_descr), + tags=self.tags, + non_equality_tags=self.non_equality_tags, + ) + else: + return self EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P[a-zA-Z])|(?P\.\.\.))\s*") From 210f248e24c416c0fbc56adef20d7a58cba8034a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Jun 2024 15:35:47 -0500 Subject: [PATCH 153/178] attempt to avoid duplication in CodeGenPreprocessor --- pytato/array.py | 52 ++++--- 
pytato/codegen.py | 106 +++++++++----- pytato/transform/lower_to_index_lambda.py | 169 ++++++++++++++-------- 3 files changed, 206 insertions(+), 121 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 53d5215b8..1bcfbe696 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1177,6 +1177,34 @@ class EinsumReductionAxis(EinsumAxisDescriptor): dim: int +def _get_einsum_access_descr_to_axis_len( + access_descriptors: tuple[tuple[EinsumAxisDescriptor, ...], ...], + args: tuple[Array, ...], + ) -> Mapping[EinsumAxisDescriptor, ShapeComponent]: + from pytato.utils import are_shape_components_equal + descr_to_axis_len: dict[EinsumAxisDescriptor, ShapeComponent] = {} + + for access_descrs, arg in zip(access_descriptors, + args, strict=True): + assert arg.ndim == len(access_descrs) + for arg_axis_len, descr in zip(arg.shape, access_descrs, strict=True): + if descr in descr_to_axis_len: + seen_axis_len = descr_to_axis_len[descr] + + if not are_shape_components_equal(seen_axis_len, + arg_axis_len): + if are_shape_components_equal(arg_axis_len, 1): + # this axis would be broadcasted + pass + else: + assert are_shape_components_equal(seen_axis_len, 1) + descr_to_axis_len[descr] = arg_axis_len + else: + descr_to_axis_len[descr] = arg_axis_len + + return immutabledict(descr_to_axis_len) + + @array_dataclass() class Einsum(_SuppliedAxesAndTagsMixin, Array): """ @@ -1224,28 +1252,8 @@ def __post_init__(self) -> None: @memoize_method def _access_descr_to_axis_len(self ) -> Mapping[EinsumAxisDescriptor, ShapeComponent]: - from pytato.utils import are_shape_components_equal - descr_to_axis_len: dict[EinsumAxisDescriptor, ShapeComponent] = {} - - for access_descrs, arg in zip(self.access_descriptors, - self.args, strict=True): - assert arg.ndim == len(access_descrs) - for arg_axis_len, descr in zip(arg.shape, access_descrs, strict=True): - if descr in descr_to_axis_len: - seen_axis_len = descr_to_axis_len[descr] - - if not are_shape_components_equal(seen_axis_len, - arg_axis_len): - if are_shape_components_equal(arg_axis_len, 1): - # this axis would be broadcasted - pass - else: - assert are_shape_components_equal(seen_axis_len, 1) - descr_to_axis_len[descr] = arg_axis_len - else: - descr_to_axis_len[descr] = arg_axis_len - - return immutabledict(descr_to_axis_len) + return _get_einsum_access_descr_to_axis_len( + self.access_descriptors, self.args) @cached_property def shape(self) -> ShapeType: diff --git a/pytato/codegen.py b/pytato/codegen.py index cb957f076..f01296e1b 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -141,54 +141,72 @@ def __init__( _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + super().__init__( + # ToIndexLambdaMixin operates on certain array types for which `shape` + # is a derived property (e.g. BasicIndex). For these types, `shape` + # is an expression that may contain duplicate nodes. Mappers do not + # traverse properties, so these expressions are not subject to any prior + # deduplication. Once transformed into an IndexLambda, however, `shape` + # becomes a field and is subject to traversal and duplication checks. + # Without `err_on_collision=False`, these duplicates would lead to + # collision errors. 
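+            # (Hypothetical illustration: a BasicIndex whose derived `shape`
+            # contains two equal but distinct Array instances lowers to an
+            # IndexLambda whose `shape` *field* then holds those duplicates,
+            # which would otherwise trip the collision check.)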
+ err_on_collision=False, + _cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target self.kernels_seen: dict[str, lp.LoopKernel] = kernels_seen or {} def map_size_param(self, expr: SizeParam) -> Array: - name = expr.name - assert name is not None - return SizeParam( # pylint: disable=missing-kwoa - name=name, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + assert expr.name is not None + return expr def map_placeholder(self, expr: Placeholder) -> Array: - name = expr.name - if name is None: - name = self.var_name_gen("_pt_in") - return Placeholder(name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s - for s in expr.shape), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_name = expr.name + if new_name is None: + new_name = self.var_name_gen("_pt_in") + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if ( + new_name is expr.name + and new_shape is expr.shape): + return expr + else: + return Placeholder(name=new_name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: from pytato.target.loopy import LoopyTarget if not isinstance(self.target, LoopyTarget): raise ValueError("Got a LoopyCall for a non-loopy target.") - translation_unit = expr.translation_unit.copy( - target=self.target.get_loopy_target()) + new_target = self.target.get_loopy_target() + + # FIXME: Can't use "is" here because targets aren't unique. Is it OK to + # use the existing target if it's equal to self.target.get_loopy_target()? + # If not, may have to set err_on_created_duplicate=False + if new_target == expr.translation_unit.target: + new_translation_unit = expr.translation_unit + else: + new_translation_unit = expr.translation_unit.copy(target=new_target) namegen = UniqueNameGenerator(set(self.kernels_seen)) - entrypoint = expr.entrypoint + new_entrypoint = expr.entrypoint # {{{ eliminate callable name collision - for name, clbl in translation_unit.callables_table.items(): + for name, clbl in new_translation_unit.callables_table.items(): if isinstance(clbl, lp.CallableKernel): assert isinstance(name, str) if name in self.kernels_seen and ( - translation_unit[name] != self.kernels_seen[name]): + new_translation_unit[name] != self.kernels_seen[name]): # callee name collision => must rename # {{{ see if it's one of the other kernels for other_knl in self.kernels_seen.values(): - if other_knl.copy(name=name) == translation_unit[name]: + if other_knl.copy(name=name) == new_translation_unit[name]: new_name = other_knl.name break else: @@ -198,37 +216,55 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: # }}} - if name == entrypoint: + if name == new_entrypoint: # if the colliding name is the entrypoint, then rename the # entrypoint as well. 
- entrypoint = new_name + new_entrypoint = new_name - translation_unit = lp.rename_callable( - translation_unit, name, new_name) + new_translation_unit = lp.rename_callable( + new_translation_unit, name, new_name) name = new_name self.kernels_seen[name] = clbl.subkernel # }}} - bindings: Mapping[str, Any] = immutabledict( + new_bindings: Mapping[str, Any] = immutabledict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - return LoopyCall(translation_unit=translation_unit, - bindings=bindings, - entrypoint=entrypoint, - tags=expr.tags - ) + assert ( + new_entrypoint is expr.entrypoint + or new_entrypoint != expr.entrypoint) + for bnd, new_bnd in zip( + expr.bindings.values(), new_bindings.values(), strict=True): + assert new_bnd is bnd or new_bnd != bnd + + if ( + new_translation_unit == expr.translation_unit + and ( + frozenset(new_bindings.keys()) + == frozenset(expr.bindings.keys())) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings) + and new_entrypoint is expr.entrypoint): + return expr + else: + return LoopyCall(translation_unit=new_translation_unit, + bindings=new_bindings, + entrypoint=new_entrypoint, + tags=expr.tags + ) def map_data_wrapper(self, expr: DataWrapper) -> Array: name = _generate_name_for_temp(expr, self.var_name_gen, "_pt_data") + shape = self.rec_idx_or_size_tuple(expr.shape) self.bound_arguments[name] = expr.data return Placeholder(name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s - for s in expr.shape), + shape=shape, dtype=expr.dtype, axes=expr.axes, tags=expr.tags, diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 507a450cd..f71ccc010 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -53,15 +53,18 @@ ShapeComponent, ShapeType, Stack, + _get_einsum_access_descr_to_axis_len, ) from pytato.diagnostic import CannotBeLoweredToIndexLambda from pytato.scalar_expr import INT_CLASSES, ScalarExpression from pytato.tags import AssumeNonNegative -from pytato.transform import Mapper +from pytato.transform import IndexOrShapeExpr, Mapper from pytato.utils import normalized_slice_does_not_change_axis if TYPE_CHECKING: + from collections.abc import Mapping + import numpy as np @@ -126,16 +129,14 @@ def _generate_index_expressions( for old_size_till, old_stride in zip(old_size_tills, old_strides, strict=True)) -def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: +def _get_reshaped_indices( + order: str, old_shape: ShapeType, new_shape: ShapeType + ) -> tuple[ScalarExpression, ...]: - if expr.order.upper() not in ["C", "F"]: + if order.upper() not in ["C", "F"]: raise NotImplementedError("Order expected to be 'C' or 'F'", " (case insensitive). 
Found order = ", - f"{expr.order}") - - order = expr.order - old_shape = expr.array.shape - new_shape = expr.shape + f"{order}") # index variables need to be unique and depend on the new shape length index_vars = [prim.Variable(f"_{i}") for i in range(len(new_shape))] @@ -143,7 +144,8 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: # {{{ check for scalars if old_shape == (): - assert expr.size == 1 + from pytools import product + assert product(new_shape) == 1 return () if new_shape == (): @@ -256,10 +258,17 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: class ToIndexLambdaMixin: - def _rec_shape(self, shape: ShapeType) -> ShapeType: - return tuple(self.rec(s) if isinstance(s, Array) - else s - for s in shape) + def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] + ) -> tuple[IndexOrShapeExpr, ...]: + # type-ignore-reason: apparently mypy cannot substitute typevars + # here. + new_situp = tuple( + self.rec(s) if isinstance(s, Array) else s + for s in situp) + if all(new_s is s for s, new_s in zip(situp, new_situp, strict=True)): + return situp + else: + return new_situp # type: ignore[return-value] if TYPE_CHECKING: def rec( @@ -270,17 +279,27 @@ def rec( return super().rec( # type: ignore[no-any-return,misc] expr, *args, **kwargs) - def map_index_lambda(self, expr: IndexLambda) -> IndexLambda: - return IndexLambda(expr=expr.expr, - shape=self._rec_shape(expr.shape), - dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) - for name, bnd - in sorted(expr.bindings.items())}), - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + def map_index_lambda(self, expr: IndexLambda) -> Array: + new_shape = self.rec_idx_or_size_tuple(expr.shape) + new_bindings: Mapping[str, Array] = immutabledict({ + name: self.rec(subexpr) + for name, subexpr in sorted(expr.bindings.items())}) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> IndexLambda: subscript = tuple(prim.Variable(f"_{i}") @@ -305,11 +324,11 @@ def map_stack(self, expr: Stack) -> IndexLambda: subarray_expr, stack_expr) - bindings = {f"_in{i}": self.rec(array) - for i, array in enumerate(expr.arrays)} + bindings = {f"_in{i}": self.rec(ary) + for i, ary in enumerate(expr.arrays)} return IndexLambda(expr=stack_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, bindings=immutabledict(bindings), @@ -328,10 +347,12 @@ def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: for i in range(len(expr.shape))] return Subscript(aggregate, tuple(index)) + rec_arrays: tuple[Array, ...] 
= tuple(self.rec(ary) for ary in expr.arrays) + lbounds: list[Any] = [0] - ubounds: list[Any] = [expr.arrays[0].shape[expr.axis]] + ubounds: list[Any] = [rec_arrays[0].shape[expr.axis]] - for i, array in enumerate(expr.arrays[1:], start=1): + for i, array in enumerate(rec_arrays[1:], start=1): ubounds.append(ubounds[i-1]+array.shape[expr.axis]) lbounds.append(ubounds[i-1]) @@ -354,11 +375,11 @@ def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: subarray_expr, concat_expr) - bindings = {f"_in{i}": self.rec(array) - for i, array in enumerate(expr.arrays)} + bindings = {f"_in{i}": ary + for i, ary in enumerate(rec_arrays)} return IndexLambda(expr=concat_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, bindings=immutabledict(bindings), axes=expr.axes, @@ -377,7 +398,9 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: dim_to_index_lambda_components, ) - bindings = {f"_in{k}": self.rec(arg) for k, arg in enumerate(expr.args)} + rec_args: tuple[Array, ...] = tuple(self.rec(arg) for arg in expr.args) + + bindings = {f"_in{k}": arg for k, arg in enumerate(rec_args)} redn_bounds: dict[str, tuple[ScalarExpression, ScalarExpression]] = {} args_as_pym_expr: list[prim.Subscript] = [] namegen = UniqueNameGenerator(set(bindings)) @@ -385,13 +408,16 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: # {{{ add bindings coming from the shape expressions + access_descr_to_axis_len = _get_einsum_access_descr_to_axis_len( + expr.access_descriptors, rec_args) + for access_descr, (iarg, arg) in zip(expr.access_descriptors, - enumerate(expr.args), strict=True): + enumerate(rec_args), strict=True): subscript_indices: list[ArithmeticExpression] = [] for iaxis, axis in enumerate(access_descr): if not are_shape_components_equal( arg.shape[iaxis], - expr._access_descr_to_axis_len()[axis]): + access_descr_to_axis_len[axis]): # axis is broadcasted assert are_shape_components_equal(arg.shape[iaxis], 1) subscript_indices.append(0) @@ -432,7 +458,7 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: immutabledict(redn_bounds)) return IndexLambda(expr=inner_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, bindings=immutabledict(bindings), axes=expr.axes, @@ -443,12 +469,14 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: def map_roll(self, expr: Roll) -> IndexLambda: from pytato.utils import dim_to_index_lambda_components + rec_array = self.rec(expr.array) + index_expr: prim.ExpressionNode = prim.Variable("_in0") indices: list[ArithmeticExpression] = [ prim.Variable(f"_{d}") for d in range(expr.ndim)] axis = expr.axis axis_len_expr, bindings = dim_to_index_lambda_components( - expr.shape[axis], + rec_array.shape[axis], UniqueNameGenerator({"_in0"})) # Mypy has a point: the type system does not prove that the operands are @@ -459,13 +487,12 @@ def map_roll(self, expr: Roll) -> IndexLambda: index_expr = index_expr[tuple(indices)] # type-ignore-reason: `bindings` was returned as Dict[str, SizeParam] - bindings["_in0"] = expr.array # type: ignore[assignment] + bindings["_in0"] = rec_array # type: ignore[assignment] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) - for name, bnd in bindings.items()}), + bindings=immutabledict(bindings), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, @@ -476,27 +503,30 @@ def 
map_contiguous_advanced_index(self, ) -> IndexLambda: from pytato.utils import get_indexing_expression, get_shape_after_broadcasting + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + i_adv_indices = tuple(i - for i, idx_expr in enumerate(expr.indices) + for i, idx_expr in enumerate(rec_indices) if isinstance(idx_expr, (Array, *INT_CLASSES))) adv_idx_shape = get_shape_after_broadcasting([ - cast("Array | int | np.integer[Any]", expr.indices[i_idx]) + cast("Array | int | np.integer[Any]", rec_indices[i_idx]) for i_idx in i_adv_indices]) vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = 0 for i_idx, (idx, axis_len) in enumerate( - zip(expr.indices, expr.array.shape, strict=True)): + zip(rec_indices, rec_array.shape, strict=True)): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -508,7 +538,7 @@ def map_contiguous_advanced_index(self, elif isinstance(idx, Array): if isinstance(axis_len, INT_CLASSES): bnd_name = vng("in") - bindings[bnd_name] = self.rec(idx) + bindings[bnd_name] = idx indirect_idx_expr: ArithmeticExpression = prim.Subscript( prim.Variable(bnd_name), get_indexing_expression( @@ -536,7 +566,7 @@ def map_contiguous_advanced_index(self, return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -547,28 +577,32 @@ def map_contiguous_advanced_index(self, def map_non_contiguous_advanced_index( self, expr: AdvancedIndexInNoncontiguousAxes) -> IndexLambda: from pytato.utils import get_indexing_expression, get_shape_after_broadcasting + + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + i_adv_indices = tuple(i - for i, idx_expr in enumerate(expr.indices) + for i, idx_expr in enumerate(rec_indices) if isinstance(idx_expr, (Array, *INT_CLASSES))) adv_idx_shape = get_shape_after_broadcasting([ - cast("Array | int | np.integer[Any]", expr.indices[i_idx]) + cast("Array | int | np.integer[Any]", rec_indices[i_idx]) for i_idx in i_adv_indices]) vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = len(adv_idx_shape) - for idx, axis_len in zip(expr.indices, expr.array.shape, strict=True): + for idx, axis_len in zip(rec_indices, rec_array.shape, strict=True): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -580,7 +614,7 @@ def map_non_contiguous_advanced_index( elif isinstance(idx, Array): if isinstance(axis_len, INT_CLASSES): bnd_name = vng("in") - bindings[bnd_name] = self.rec(idx) + bindings[bnd_name] = idx indirect_idx_expr: ArithmeticExpression = 
prim.Subscript( prim.Variable(bnd_name), @@ -605,7 +639,7 @@ def map_non_contiguous_advanced_index( return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -614,20 +648,23 @@ def map_non_contiguous_advanced_index( ) def map_basic_index(self, expr: BasicIndex) -> IndexLambda: + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = 0 - for idx, axis_len in zip(expr.indices, expr.array.shape, strict=True): + for idx, axis_len in zip(rec_indices, rec_array.shape, strict=True): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -642,7 +679,7 @@ def map_basic_index(self, expr: BasicIndex) -> IndexLambda: return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -651,18 +688,22 @@ def map_basic_index(self, expr: BasicIndex) -> IndexLambda: ) def map_reshape(self, expr: Reshape) -> IndexLambda: - indices = _get_reshaped_indices(expr) + rec_array = self.rec(expr.array) + rec_newshape = self.rec_idx_or_size_tuple(expr.shape) + indices = _get_reshaped_indices(expr.order, rec_array.shape, rec_newshape) index_expr = prim.Variable("_in0")[tuple(indices)] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=rec_newshape, dtype=expr.dtype, - bindings=immutabledict({"_in0": self.rec(expr.array)}), + bindings=immutabledict({"_in0": rec_array}), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: + rec_array = self.rec(expr.array) + indices: list[ArithmeticExpression | None] = [None] * expr.ndim for from_index, to_index in enumerate(expr.axis_permutation): indices[to_index] = prim.Variable(f"_{from_index}") @@ -671,9 +712,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: cast("tuple[ArithmeticExpression]", tuple(indices))] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, - bindings=immutabledict({"_in0": self.rec(expr.array)}), + bindings=immutabledict({"_in0": rec_array}), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, From c9157340dbd77709970ddaab3425c5f5f20b0c3b Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 3 Jul 2024 15:54:52 -0500 Subject: [PATCH 154/178] limit PlaceholderSubstitutor to one call stack frame --- pytato/transform/calls.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 298b5351b..d9bc078fd 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -65,9 
+65,10 @@ def __init__(self, substitutions: Mapping[str, Array]) -> None: def map_placeholder(self, expr: Placeholder) -> Array: return self.substitutions[expr.name] - def map_named_call_result(self, expr: NamedCallResult) -> NamedCallResult: - raise NotImplementedError( - "PlaceholderSubstitutor does not support functions.") + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + # Only operates within the current stack frame + return expr class Inliner(CopyMapper): From 98802c27b2ae0654d7db0c92f2e54bff62a7e54b Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 11 Jul 2024 23:02:13 -0500 Subject: [PATCH 155/178] tweak Inliner/PlaceholderSubstitutor implementations --- pytato/transform/calls.py | 55 ++++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index d9bc078fd..2cf8c6479 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -30,7 +30,9 @@ """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast + +from typing_extensions import Self from pytato.array import ( AbstractResultWithNamedArrays, @@ -38,9 +40,14 @@ DictOfNamedArrays, Placeholder, ) -from pytato.function import Call, NamedCallResult +from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.tags import InlineCallTag -from pytato.transform import ArrayOrNames, CopyMapper, _verify_is_array +from pytato.transform import ( + ArrayOrNames, + CopyMapper, + TransformMapperCache, + _verify_is_array, +) if TYPE_CHECKING: @@ -55,6 +62,12 @@ class PlaceholderSubstitutor(CopyMapper): A mapping from the placeholder name to the array that it is to be substituted with. + + .. note:: + + This mapper does not deduplicate subexpressions that occur in both the mapped + expression and the substitutions. Must follow up with a + :class:`pytato.transform.Deduplicator` if duplicates need to be removed. """ def __init__(self, substitutions: Mapping[str, Array]) -> None: @@ -63,6 +76,9 @@ def __init__(self, substitutions: Mapping[str, Array]) -> None: self.substitutions = substitutions def map_placeholder(self, expr: Placeholder) -> Array: + # Can't call rec() to remove duplicates here, because the substituted-in + # expression may potentially contain unrelated placeholders whose names + # collide with the ones being replaced return self.substitutions[expr.name] def map_function_definition( @@ -75,21 +91,36 @@ class Inliner(CopyMapper): """ Primary mapper for :func:`inline_calls`. """ - def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - # inline call sites within the callee. 
- new_expr = super().map_call(expr) - assert isinstance(new_expr, Call) + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + # Must disable collision/duplication checking because we're combining + # expressions that were previously in two different call stack frames + # (and were thus cached separately) + super().__init__( + err_on_collision=False, + err_on_created_duplicate=False, + _cache=_cache, + _function_cache=_function_cache) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: if expr.tags_of_type(InlineCallTag): - substitutor = PlaceholderSubstitutor(new_expr.bindings) + substitutor = PlaceholderSubstitutor(expr.bindings) return DictOfNamedArrays( - {name: _verify_is_array(substitutor.rec(ret)) - for name, ret in new_expr.function.returns.items()}, - tags=new_expr.tags + {name: _verify_is_array(self.rec(substitutor(ret))) + for name, ret in expr.function.returns.items()}, + tags=expr.tags ) else: - return new_expr + return super().map_call(expr) def map_named_call_result(self, expr: NamedCallResult) -> Array: new_call_or_inlined_expr = self.rec(expr._container) From 8bfc0d2730a2c89a070e39ee1a446022588656a2 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 16 Jul 2024 13:31:48 -0500 Subject: [PATCH 156/178] use context manager to avoid leaking traceback tag setting in test --- test/test_pytato.py | 139 +++++++++++++++++++++++--------------------- 1 file changed, 73 insertions(+), 66 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index da176f124..e0e1f90d5 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -29,6 +29,7 @@ import dataclasses import sys +from contextlib import contextmanager import numpy as np import pytest @@ -932,111 +933,117 @@ def test_einsum_dot_axes_has_correct_dim(): assert len(einsum.axes) == einsum.ndim -def test_created_at(): - pt.set_traceback_tag_enabled() +@contextmanager +def enable_traceback_tag(): + try: + pt.set_traceback_tag_enabled(True) + yield + finally: + pt.set_traceback_tag_enabled(False) - a = pt.make_placeholder("a", (10, 10), "float64") - b = pt.make_placeholder("b", (10, 10), "float64") - # res1 and res2 are defined on different lines and should have different - # CreatedAt tags. - res1 = a+b - res2 = a+b +def test_created_at(): + with enable_traceback_tag(): + a = pt.make_placeholder("a", (10, 10), "float64") + b = pt.make_placeholder("b", (10, 10), "float64") - # res3 and res4 are defined on the same line and should have the same - # CreatedAt tags. - res3 = a+b; res4 = a+b # noqa: E702 + # res1 and res2 are defined on different lines and should have different + # CreatedAt tags. + res1 = a+b + res2 = a+b - # {{{ Check that CreatedAt tags are handled correctly for equality/hashing + # res3 and res4 are defined on the same line and should have the same + # CreatedAt tags. 
+ res3 = a+b; res4 = a+b # noqa: E702 - assert res1 == res2 == res3 == res4 - assert hash(res1) == hash(res2) == hash(res3) == hash(res4) + # {{{ Check that CreatedAt tags are handled correctly for equality/hashing - assert res1.non_equality_tags != res2.non_equality_tags - assert res3.non_equality_tags == res4.non_equality_tags - assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) - assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) + assert res1 == res2 == res3 == res4 + assert hash(res1) == hash(res2) == hash(res3) == hash(res4) - assert res1.tags == res2.tags == res3.tags == res4.tags - assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) + assert res1.non_equality_tags != res2.non_equality_tags + assert res3.non_equality_tags == res4.non_equality_tags + assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) + assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) - # }}} + assert res1.tags == res2.tags == res3.tags == res4.tags + assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) - from pytato.tags import CreatedAt + # }}} - created_tag = frozenset({tag - for tag in res1.non_equality_tags - if isinstance(tag, CreatedAt)}) + from pytato.tags import CreatedAt - assert len(created_tag) == 1 + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) - # {{{ Make sure the function name appears in the traceback + assert len(created_tag) == 1 - tag, = created_tag + # {{{ Make sure the function name appears in the traceback - found = False + tag, = created_tag - stacksummary = tag.traceback.to_stacksummary() - assert len(stacksummary) > 10 + found = False - for frame in tag.traceback.frames: - if frame.name == "test_created_at" and "a+b" in frame.line: - found = True - break + stacksummary = tag.traceback.to_stacksummary() + assert len(stacksummary) > 10 - assert found + for frame in tag.traceback.frames: + if frame.name == "test_created_at" and "a+b" in frame.line: + found = True + break - # }}} + assert found - # {{{ Make sure that CreatedAt tags are in the visualization + # }}} - from pytato.visualization import get_dot_graph - s = get_dot_graph(res1) - assert "test_created_at" in s - assert "a+b" in s + # {{{ Make sure that CreatedAt tags are in the visualization - # }}} + from pytato.visualization import get_dot_graph + s = get_dot_graph(res1) + assert "test_created_at" in s + assert "a+b" in s - # {{{ Make sure only a single CreatedAt tag is created + # }}} - old_tag = tag + # {{{ Make sure only a single CreatedAt tag is created - res1 = res1 + res2 + old_tag = tag - created_tag = frozenset({tag - for tag in res1.non_equality_tags - if isinstance(tag, CreatedAt)}) + res1 = res1 + res2 - assert len(created_tag) == 1 + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) - tag, = created_tag + assert len(created_tag) == 1 - # Tag should be recreated - assert tag != old_tag + tag, = created_tag - # }}} + # Tag should be recreated + assert tag != old_tag - # {{{ Make sure that copying preserves the tag + # }}} - old_tag = tag + # {{{ Make sure that copying preserves the tag - res1_new = pt.transform.map_and_copy(res1, lambda x: x) + old_tag = tag - created_tag = frozenset({tag - for tag in res1_new.non_equality_tags - if isinstance(tag, CreatedAt)}) + res1_new = pt.transform.map_and_copy(res1, lambda x: x) - assert len(created_tag) == 1 + created_tag = frozenset({tag + for tag in 
res1_new.non_equality_tags + if isinstance(tag, CreatedAt)}) - tag, = created_tag + assert len(created_tag) == 1 - assert old_tag == tag + tag, = created_tag - # }}} + assert old_tag == tag - # {{{ Test disabling traceback creation + # }}} - pt.set_traceback_tag_enabled(False) + # {{{ Test disabling traceback creation a = pt.make_placeholder("a", (10, 10), "float64") From 329c3882008d701df6f7b2cf86ba47e7bef61868 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 16 Jul 2024 16:57:54 -0500 Subject: [PATCH 157/178] refactor FFTRealizationMapper to avoid resetting cache in __init__ --- test/test_apps.py | 71 +++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/test/test_apps.py b/test/test_apps.py index f39be848c..bdb3afc14 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -39,7 +39,7 @@ from pytools.tag import Tag, tag_dataclass import pytato as pt -from pytato.transform import CopyMapper, WalkMapper +from pytato.transform import CopyMapper, Deduplicator, WalkMapper # {{{ Trace an FFT @@ -78,40 +78,21 @@ def map_constant(self, expr): class FFTRealizationMapper(CopyMapper): - def __init__(self, fft_vec_gatherer): - super().__init__() - - self.fft_vec_gatherer = fft_vec_gatherer - - self.old_array_to_new_array = {} - levels = sorted(fft_vec_gatherer.level_to_arrays, reverse=True) - - lev = 0 - arrays = fft_vec_gatherer.level_to_arrays[lev] - self.finalized = False - - for lev in levels: - arrays = fft_vec_gatherer.level_to_arrays[lev] - rec_arrays = [self.rec(ary) for ary in arrays] - # reset cache so that the partial subs are not stored - self._cache.clear() - lev_array = pt.concatenate(rec_arrays, axis=0) - assert lev_array.shape == (fft_vec_gatherer.n,) - - startidx = 0 - for array in arrays: - size = array.shape[0] - sub_array = lev_array[startidx:startidx+size] - startidx += size - self.old_array_to_new_array[array] = sub_array - - assert startidx == fft_vec_gatherer.n - self.finalized = True + def __init__(self, old_array_to_new_array): + # Must use err_on_created_duplicate=False, because the use of ConstantSizer + # in map_index_lambda creates IndexLambdas that differ only in the type of + # their contained constants, which changes their identity but not their + # equality + super().__init__(err_on_created_duplicate=False) + self.old_array_to_new_array = old_array_to_new_array def map_index_lambda(self, expr): tags = expr.tags_of_type(FFTIntermediate) - if tags and (self.finalized or expr in self.old_array_to_new_array): - return self.old_array_to_new_array[expr] + if tags: + try: + return self.old_array_to_new_array[expr] + except KeyError: + pass return super().map_index_lambda( expr.copy(expr=ConstantSizer()(expr.expr))) @@ -122,6 +103,29 @@ def map_concatenate(self, expr): (ImplStored(), PrefixNamed("concat"))) +def make_fft_realization_mapper(fft_vec_gatherer): + old_array_to_new_array = {} + levels = sorted(fft_vec_gatherer.level_to_arrays, reverse=True) + + for lev in levels: + lev_mapper = FFTRealizationMapper(old_array_to_new_array) + arrays = fft_vec_gatherer.level_to_arrays[lev] + rec_arrays = [lev_mapper(ary) for ary in arrays] + lev_array = pt.concatenate(rec_arrays, axis=0) + assert lev_array.shape == (fft_vec_gatherer.n,) + + startidx = 0 + for array in arrays: + size = array.shape[0] + sub_array = lev_array[startidx:startidx+size] + startidx += size + old_array_to_new_array[array] = sub_array + + assert startidx == fft_vec_gatherer.n + + return FFTRealizationMapper(old_array_to_new_array) + + def 
test_trace_fft(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -134,10 +138,11 @@ def test_trace_fft(ctx_factory): wrap_intermediate_with_level=( lambda level, ary: ary.tagged(FFTIntermediate(level)))) + result = Deduplicator()(result) fft_vec_gatherer = FFTVectorGatherer(n) fft_vec_gatherer(result) - mapper = FFTRealizationMapper(fft_vec_gatherer) + mapper = make_fft_realization_mapper(fft_vec_gatherer) result = mapper(result) From 9ea411f301c7f35fbeeda6a49bfb7e00bd09c0a5 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 27 Aug 2024 16:10:14 -0500 Subject: [PATCH 158/178] add allow_duplicate_nodes option to RandomDAGContext in tests --- test/testlib.py | 47 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/test/testlib.py b/test/testlib.py index a28dec67e..7d58df480 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -101,6 +101,7 @@ def __init__( rng: np.random.Generator, axis_len: int, use_numpy: bool, + allow_duplicate_nodes: bool = False, additional_generators: ( Sequence[tuple[int, Callable[[RandomDAGContext], Array]]] | None) = None @@ -115,6 +116,7 @@ def __init__( self.axis_len = axis_len self.past_results: list[Array] = [] self.use_numpy = use_numpy + self.allow_duplicate_nodes = allow_duplicate_nodes if additional_generators is None: additional_generators = [] @@ -156,6 +158,14 @@ def make_random_reshape( def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: + if not rdagc.use_numpy and not rdagc.allow_duplicate_nodes: + def dedup(expr: Array) -> Array: + return pt.transform._verify_is_array(pt.transform.Deduplicator()(expr)) + + else: + def dedup(expr: Array) -> Array: + return expr + rng = rdagc.rng max_prob_hardcoded = 1500 @@ -166,7 +176,7 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: v = rng.integers(0, max_prob_hardcoded + additional_prob) if v < 600: - return make_random_constant(rdagc, naxes=rng.integers(1, 3)) + return dedup(make_random_constant(rdagc, naxes=rng.integers(1, 3))) elif v < 1000: op1 = make_random_dag(rdagc) @@ -189,9 +199,9 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: # just inserted a few new 1-long axes. Those need to go before we # return. 
if which_op in ["maximum", "minimum"]: - return rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2)) + return dedup(rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2))) else: - return rdagc.np.squeeze(which_op(op1, op2)) + return dedup(rdagc.np.squeeze(which_op(op1, op2))) elif v < 1075: op1 = make_random_dag(rdagc) @@ -199,24 +209,26 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: if op1.ndim <= 1 and op2.ndim <= 1: continue - return op1 @ op2 + return dedup(op1 @ op2) elif v < 1275: if not rdagc.past_results: continue - return rdagc.past_results[rng.integers(0, len(rdagc.past_results))] + return dedup( + rdagc.past_results[rng.integers(0, len(rdagc.past_results))]) elif v < max_prob_hardcoded: result = make_random_dag(rdagc) - return rdagc.np.transpose( + return dedup( + rdagc.np.transpose( result, - tuple(rng.permuted(list(range(result.ndim))))) + tuple(rng.permuted(list(range(result.ndim)))))) else: base_prob = max_prob_hardcoded for fake_prob, gen_func in rdagc.additional_generators: if base_prob <= v < base_prob + fake_prob: - return gen_func(rdagc) + return dedup(gen_func(rdagc)) base_prob += fake_prob @@ -237,6 +249,14 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: of the array are of length :attr:`RandomDAGContext.axis_len` (there is at least one axis, but arbitrarily more may be present). """ + if not rdagc.use_numpy and not rdagc.allow_duplicate_nodes: + def dedup(expr: Array) -> Array: + return pt.transform._verify_is_array(pt.transform.Deduplicator()(expr)) + + else: + def dedup(expr: Array) -> Array: + return expr + rng = rdagc.rng result = make_random_dag_inner(rdagc) @@ -248,14 +268,15 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: subscript[rng.integers(0, result.ndim)] = int( rng.integers(0, rdagc.axis_len)) - return result[tuple(subscript)] + return dedup(result[tuple(subscript)]) elif v == 1: # reduce away an axis # FIXME do reductions other than sum? 
- return rdagc.np.sum( - result, axis=int(rng.integers(0, result.ndim))) + return dedup( + rdagc.np.sum( + result, axis=int(rng.integers(0, result.ndim)))) else: raise AssertionError() @@ -275,7 +296,8 @@ def get_random_pt_dag(seed: int, Sequence[tuple[int, Callable[[RandomDAGContext], Array]]] | None) = None, axis_len: int = 4, - convert_dws_to_placeholders: bool = False + convert_dws_to_placeholders: bool = False, + allow_duplicate_nodes: bool = False ) -> pt.DictOfNamedArrays: if additional_generators is None: additional_generators = [] @@ -286,6 +308,7 @@ def get_random_pt_dag(seed: int, rdagc_comm = RandomDAGContext(np.random.default_rng(seed=seed), axis_len=axis_len, use_numpy=False, + allow_duplicate_nodes=allow_duplicate_nodes, additional_generators=additional_generators) dag = pt.make_dict_of_named_arrays({"result": make_random_dag(rdagc_comm)}) From 32c65918ce65a447648d97ea3264bfd27d139a06 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 27 Aug 2024 16:10:44 -0500 Subject: [PATCH 159/178] fix some more tests --- pytato/utils.py | 2 ++ test/test_codegen.py | 17 ++++++++++------- test/test_distributed.py | 5 +++-- test/test_pytato.py | 10 ++++++---- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index 31247897d..77cecc3bd 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -340,8 +340,10 @@ def are_shape_components_equal( if isinstance(dim1, INT_CLASSES) and isinstance(dim2, INT_CLASSES): return dim1 == dim2 + from pytato.transform import Deduplicator dim1_minus_dim2 = dim1 - dim2 assert isinstance(dim1_minus_dim2, Array) + dim1_minus_dim2 = Deduplicator()(dim1_minus_dim2) from pytato.transform import InputGatherer inputs = InputGatherer()(dim1_minus_dim2) diff --git a/test/test_codegen.py b/test/test_codegen.py index 2193b7fc9..a4ca538de 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -926,7 +926,7 @@ def _get_x_shape(_m, n_): x = pt.make_data_wrapper(x_in, shape=_get_x_shape(m, n)) np_out = np.einsum("ij, j -> i", A_in, x_in) - pt_expr = pt.einsum("ij, j -> i", A, x) + pt_expr = pt.transform.Deduplicator()(pt.einsum("ij, j -> i", A, x)) _, (pt_out,) = pt.generate_loopy(pt_expr)(cq, m=m_in, n=n_in) @@ -1582,8 +1582,9 @@ def get_np_input_args(): np_inputs = get_np_input_args() np_result = kernel(np, **np_inputs) - pt_dag = kernel(pt, **{kw: pt.make_data_wrapper(arg) - for kw, arg in np_inputs.items()}) + pt_dag = pt.transform.Deduplicator()( + kernel(pt, **{kw: pt.make_data_wrapper(arg) + for kw, arg in np_inputs.items()})) knl = pt.generate_loopy(pt_dag, options=lp.Options(write_code=True)) @@ -1939,10 +1940,12 @@ def build_expression(tracer): "baz": 65 * twice_x, "quux": 7 * twice_x_2} - result_with_functions = pt.tag_all_calls_to_be_inlined( - pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) - result_without_functions = pt.make_dict_of_named_arrays( - build_expression(lambda fn, *args: fn(*args))) + expr = pt.transform.Deduplicator()( + pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) + + result_with_functions = pt.tag_all_calls_to_be_inlined(expr) + result_without_functions = pt.transform.Deduplicator()( + pt.make_dict_of_named_arrays(build_expression(lambda fn, *args: fn(*args)))) # test that visualizing graphs with functions works dot = pt.get_dot_graph(result_with_functions) diff --git a/test/test_distributed.py b/test/test_distributed.py index d78479e08..65214c4b0 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -555,12 +555,13 @@ def 
_test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory): x_np = rng.random((10, 4)) x = pt.make_data_wrapper(cla.to_device(queue, x_np)) y = 2 * x + ones = pt.ones(10) send1 = pt.staple_distributed_send( y, dest_rank=1, comm_tag=42, - stapled_to=pt.ones(10)) + stapled_to=ones) send2 = pt.staple_distributed_send( y, dest_rank=2, comm_tag=42, - stapled_to=pt.ones(10)) + stapled_to=ones) z = 4 * y dag = pt.make_dict_of_named_arrays({"z": z, "send1": send1, "send2": send2}) else: diff --git a/test/test_pytato.py b/test/test_pytato.py index e0e1f90d5..7fe5a3b49 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -724,7 +724,7 @@ def test_small_dag_with_duplicates_count(): # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + pt.transform.DependencyMapper(err_on_collision=False)(dag)) assert node_count - num_duplicates == get_num_nodes( dag, count_duplicates=False) @@ -761,7 +761,7 @@ def test_large_dag_with_duplicates_count(): # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + pt.transform.DependencyMapper(err_on_collision=False)(dag)) assert node_count - num_duplicates == get_num_nodes( dag, count_duplicates=False) @@ -806,6 +806,8 @@ def post_visit(self, expr): assert expr.name == "x" expr, inp = construct_intestine_graph() + expr = pt.transform.Deduplicator()(expr) + result = pt.transform.rec_get_user_nodes(expr, inp) SubexprRecorder()(expr) @@ -1029,7 +1031,7 @@ def test_created_at(): old_tag = tag - res1_new = pt.transform.map_and_copy(res1, lambda x: x) + res1_new = pt.transform.Deduplicator()(res1) created_tag = frozenset({tag for tag in res1_new.non_equality_tags @@ -1167,7 +1169,7 @@ class ExistentTag(Tag): out = make_random_dag(rdagc_pt).tagged(ExistentTag()) - dag = pt.make_dict_of_named_arrays({"out": out}) + dag = pt.transform.Deduplicator()(pt.make_dict_of_named_arrays({"out": out})) # get_num_nodes() returns an extra DictOfNamedArrays node assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) From 00aed84b9a9cbc67a8af4e8bf12d49375303ace9 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Sep 2024 22:51:56 -0500 Subject: [PATCH 160/178] don't check for collisions in ArrayToDotNodeInfoMapper --- pytato/visualization/dot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index c0c3e7945..7420d1708 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -178,9 +178,10 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" +# FIXME: Make this inherit from CachedWalkMapper instead? 
class ArrayToDotNodeInfoMapper(CachedMapper[None, None, []]): def __init__(self) -> None: - super().__init__() + super().__init__(err_on_collision=False) self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {} self.functions: set[FunctionDefinition] = set() From 02422dd591034678819d853b0708afc931fc6b27 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 18 Sep 2024 15:40:25 -0500 Subject: [PATCH 161/178] avoid duplication in MPMSMaterializer now inherits from CachedMapper --- pytato/transform/__init__.py | 319 +++++++++++++++++++++++++---------- 1 file changed, 231 insertions(+), 88 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index c87435672..0330f25ac 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -539,6 +539,31 @@ def clone_for_callee( # {{{ TransformMapper +def _is_mapper_created_duplicate(expr: CacheExprT, result: CacheExprT) -> bool: + """Returns *True* if *result* is not identical to *expr* when it ought to be.""" + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + return ( + hash(result) == hash(expr) + and result == expr + and result is not expr + # Only consider "direct" duplication, not duplication resulting from + # equality-preserving changes to predecessors. Assume that such changes are + # OK, otherwise they would have been detected at the point at which they + # originated. (For example, consider a DAG containing pre-existing + # duplicates. If a subexpression of *expr* is a duplicate and is replaced + # with a previously encountered version from the cache, a new instance of + # *expr* must be created. This should not trigger an error.) + and all( + result_pred is pred + for pred, result_pred in zip( + # type-ignore-reason: mypy doesn't seem to recognize overloaded + # Mapper.__call__ here + pred_getter(expr), # type: ignore[arg-type] + pred_getter(result), # type: ignore[arg-type] + strict=True))) + + class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): """ Cache for :class:`TransformMapper` and :class:`TransformMapperWithExtraArgs`. @@ -582,31 +607,10 @@ def add( try: result = self._result_to_cached_result[result] except KeyError: - if self.err_on_created_duplicate: - from pytato.analysis import DirectPredecessorsGetter - pred_getter = DirectPredecessorsGetter(include_functions=True) - if ( - hash(result) == hash(inputs.expr) - and result == inputs.expr - and result is not inputs.expr - # Only consider "direct" duplication, not duplication - # resulting from equality-preserving changes to predecessors. - # Assume that such changes are OK, otherwise they would have - # been detected at the point at which they originated. (For - # example, consider a DAG containing pre-existing duplicates. - # If a subexpression of *expr* is a duplicate and is replaced - # with a previously encountered version from the cache, a - # new instance of *expr* must be created. This should not - # trigger an error.) 
- and all( - result_pred is pred - for pred, result_pred in zip( - # type-ignore-reason: mypy doesn't seem to recognize - # overloaded Mapper.__call__ here - pred_getter(inputs.expr), # type: ignore[arg-type] - pred_getter(result), # type: ignore[arg-type] - strict=True))): - raise MapperCreatedDuplicateError from None + if ( + self.err_on_created_duplicate + and _is_mapper_created_duplicate(inputs.expr, result)): + raise MapperCreatedDuplicateError from None self._result_to_cached_result[result] = result @@ -2043,6 +2047,65 @@ class MPMSMaterializerAccumulator: expr: Array +class MPMSMaterializerCache( + CachedMapperCache[ArrayOrNames, MPMSMaterializerAccumulator, []]): + """ + Cache for :class:`MPMSMaterializer`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_created_duplicate: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_created_duplicate = err_on_created_duplicate + + self._result_key_to_result: dict[ + ArrayOrNames, MPMSMaterializerAccumulator] = {} + + def add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + try: + result = self._result_key_to_result[result.expr] + except KeyError: + if ( + self.err_on_created_duplicate + and _is_mapper_created_duplicate(inputs.expr, result.expr)): + raise MapperCreatedDuplicateError from None + + self._result_key_to_result[result.expr] = result + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result + + def _materialize_if_mpms(expr: Array, nsuccessors: int, predecessors: Iterable[MPMSMaterializerAccumulator] @@ -2060,13 +2123,16 @@ def _materialize_if_mpms(expr: Array, for pred in predecessors), frozenset()) if nsuccessors > 1 and len(materialized_predecessors) > 1: - new_expr = expr.tagged(ImplStored()) + if not expr.tags_of_type(ImplStored): + new_expr = expr.tagged(ImplStored()) + else: + new_expr = expr return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr) else: return MPMSMaterializerAccumulator(materialized_predecessors, expr) -class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): +class MPMSMaterializer(CachedMapper[MPMSMaterializerAccumulator, Never, []]): """ See :func:`materialize_with_mpms` for an explanation. @@ -2075,17 +2141,41 @@ class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): A mapping from a node in the expression graph (i.e. an :class:`~pytato.Array`) to its number of successors. 
""" - def __init__(self, nsuccessors: Mapping[Array, int]): - super().__init__() + def __init__( + self, + nsuccessors: Mapping[Array, int], + _cache: MPMSMaterializerCache | None = None): + err_on_collision = False + err_on_created_duplicate = False + + if _cache is None: + _cache = MPMSMaterializerCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + # Does not support functions, so function_cache is ignored + super().__init__(err_on_collision=err_on_collision, _cache=_cache) + self.nsuccessors = nsuccessors - self.cache: dict[ArrayOrNames, MPMSMaterializerAccumulator] = {} - def rec(self, expr: ArrayOrNames) -> MPMSMaterializerAccumulator: - if expr in self.cache: - return self.cache[expr] - result: MPMSMaterializerAccumulator = super().rec(expr) - self.cache[expr] = result - return result + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee( + self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + raise AssertionError("Control shouldn't reach this point.") def _map_input_base(self, expr: InputArgumentBase ) -> MPMSMaterializerAccumulator: @@ -2102,24 +2192,40 @@ def map_named_array(self, expr: NamedArray) -> MPMSMaterializerAccumulator: def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: children_rec = {bnd_name: self.rec(bnd) for bnd_name, bnd in sorted(expr.bindings.items())} + new_children: Mapping[str, Array] = immutabledict({ + bnd_name: bnd.expr + for bnd_name, bnd in sorted(children_rec.items())}) + + if ( + frozenset(new_children.keys()) == frozenset(expr.bindings.keys()) + and all( + new_children[name] is expr.bindings[name] + for name in expr.bindings)): + new_expr = expr + else: + new_expr = IndexLambda( + expr=expr.expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=new_children, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - new_expr = IndexLambda(expr=expr.expr, - shape=expr.shape, - dtype=expr.dtype, - bindings=immutabledict({bnd_name: bnd.expr - for bnd_name, bnd in sorted(children_rec.items())}), - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], children_rec.values()) def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] - new_expr = Stack(tuple(ary.expr for ary in rec_arrays), - expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(ary.expr for ary in rec_arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + new_expr = expr + else: + new_expr = Stack(new_arrays, expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -2127,29 +2233,44 @@ def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: rec_arrays = 
[self.rec(ary) for ary in expr.arrays] - new_expr = Concatenate(tuple(ary.expr for ary in rec_arrays), - expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(ary.expr for ary in rec_arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + new_expr = expr + else: + new_expr = Concatenate(new_arrays, + expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], rec_arrays) def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) def map_axis_permutation(self, expr: AxisPermutation ) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -2159,16 +2280,23 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: rec_indices = {i: self.rec(idx) for i, idx in enumerate(expr.indices) if isinstance(idx, Array)} - - new_expr = type(expr)(rec_array.expr, - tuple(rec_indices[i].expr - if i in rec_indices - else expr.indices[i] - for i in range( - len(expr.indices))), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_indices = tuple(rec_indices[i].expr + if i in rec_indices + else expr.indices[i] + for i in range( + len(expr.indices))) + if ( + rec_array.expr is expr.array + and all( + new_idx is idx + for idx, new_idx in zip(expr.indices, new_indices, strict=True))): + new_expr = expr + else: + new_expr = type(expr)(rec_array.expr, + new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -2181,26 +2309,35 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Reshape(rec_array.expr, expr.newshape, - expr.order, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = Reshape(rec_array.expr, expr.newshape, + expr.order, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: - rec_arrays = [self.rec(ary) for ary in expr.args] - new_expr = Einsum(expr.access_descriptors, - tuple(ary.expr for ary in rec_arrays), - expr.redn_axis_to_redn_descr, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + rec_args = [self.rec(ary) for ary in expr.args] + new_args = 
tuple(ary.expr for ary in rec_args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + new_expr = expr + else: + new_expr = Einsum(expr.access_descriptors, + new_args, + expr.redn_axis_to_redn_descr, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], - rec_arrays) + rec_args) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays ) -> MPMSMaterializerAccumulator: @@ -2213,15 +2350,21 @@ def map_loopy_call_result(self, expr: NamedArray) -> MPMSMaterializerAccumulator def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder ) -> MPMSMaterializerAccumulator: - rec_passthrough = self.rec(expr.passthrough_data) rec_send_data = self.rec(expr.send.data) - new_expr = DistributedSendRefHolder( - send=DistributedSend(rec_send_data.expr, - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag, - tags=expr.send.tags), - passthrough_data=rec_passthrough.expr, - ) + if rec_send_data.expr is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + rec_send_data.expr, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag, + tags=expr.send.tags) + rec_passthrough = self.rec(expr.passthrough_data) + if new_send is expr.send and rec_passthrough.expr is expr.passthrough_data: + new_expr = expr + else: + new_expr = DistributedSendRefHolder(new_send, rec_passthrough.expr) + return MPMSMaterializerAccumulator( rec_passthrough.materialized_predecessors, new_expr) From 888451a23c3ab38522de0c68fc3b9bf23dcbfea4 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 6 Feb 2025 21:43:35 -0600 Subject: [PATCH 162/178] avoid duplicates in EinsumWithNoBroadcastsRewriter --- pytato/transform/remove_broadcasts_einsum.py | 110 +++++++++++++++---- 1 file changed, 91 insertions(+), 19 deletions(-) diff --git a/pytato/transform/remove_broadcasts_einsum.py b/pytato/transform/remove_broadcasts_einsum.py index 2d8f7e0f0..50ee4967c 100644 --- a/pytato/transform/remove_broadcasts_einsum.py +++ b/pytato/transform/remove_broadcasts_einsum.py @@ -28,46 +28,118 @@ THE SOFTWARE. """ -from typing import cast +from typing import TYPE_CHECKING, cast from pytato.array import Array, Einsum, EinsumAxisDescriptor -from pytato.transform import CopyMapper, MappedT, _verify_is_array +from pytato.transform import ( + ArrayOrNames, + CacheKeyT, + CopyMapperWithExtraArgs, + MappedT, + Mapper, + _verify_is_array, +) from pytato.utils import are_shape_components_equal -class EinsumWithNoBroadcastsRewriter(CopyMapper): - def map_einsum(self, expr: Einsum) -> Array: +if TYPE_CHECKING: + from pytato.function import FunctionDefinition + + +class EinsumWithNoBroadcastsRewriter(CopyMapperWithExtraArgs[[tuple[int, ...] | None]]): + def get_cache_key( + self, + expr: ArrayOrNames, + axes_to_squeeze: tuple[int, ...] | None = None + ) -> CacheKeyT: + return (expr, axes_to_squeeze) + + def get_function_definition_cache_key( + self, + expr: FunctionDefinition, + axes_to_squeeze: tuple[int, ...] | None = None + ) -> CacheKeyT: + assert axes_to_squeeze is None + return expr + + def _squeeze_axes( + self, + expr: Array, + axes_to_squeeze: tuple[int, ...] | None = None) -> Array: + result = ( + expr[ + tuple( + slice(None) if idim not in axes_to_squeeze else 0 + for idim in range(expr.ndim))] + if axes_to_squeeze else expr) + return result + + def rec( + self, + expr: ArrayOrNames, + axes_to_squeeze: tuple[int, ...] 
| None = None) -> ArrayOrNames: + inputs = self._make_cache_inputs(expr, axes_to_squeeze) + try: + return self._cache_retrieve(inputs) + except KeyError: + rec_result: ArrayOrNames = Mapper.rec(self, expr, None) + result: ArrayOrNames + if isinstance(expr, Array): + result = self._squeeze_axes( + _verify_is_array(rec_result), + axes_to_squeeze) + else: + result = rec_result + return self._cache_add(inputs, result) + + def map_einsum( + self, expr: Einsum, axes_to_squeeze: tuple[int, ...] | None) -> Array: new_args: list[Array] = [] new_access_descriptors: list[tuple[EinsumAxisDescriptor, ...]] = [] descr_to_axis_len = expr._access_descr_to_axis_len() - for acc_descrs, arg in zip(expr.access_descriptors, expr.args, strict=True): - arg = _verify_is_array(self.rec(arg)) - axes_to_squeeze: list[int] = [] + for arg, acc_descrs in zip(expr.args, expr.access_descriptors, strict=True): + axes_to_squeeze_list: list[int] = [] for idim, acc_descr in enumerate(acc_descrs): if not are_shape_components_equal(arg.shape[idim], descr_to_axis_len[acc_descr]): assert are_shape_components_equal(arg.shape[idim], 1) - axes_to_squeeze.append(idim) + axes_to_squeeze_list.append(idim) + axes_to_squeeze = tuple(axes_to_squeeze_list) if axes_to_squeeze: - arg = arg[tuple(slice(None) if idim not in axes_to_squeeze else 0 - for idim in range(arg.ndim))] - acc_descrs = tuple(acc_descr + new_arg = _verify_is_array(self.rec(arg, axes_to_squeeze)) + new_acc_descrs = tuple(acc_descr for idim, acc_descr in enumerate(acc_descrs) if idim not in axes_to_squeeze) + else: + new_arg = _verify_is_array(self.rec(arg)) + new_acc_descrs = acc_descrs - new_args.append(arg) - new_access_descriptors.append(acc_descrs) + new_args.append(new_arg) + new_access_descriptors.append(new_acc_descrs) assert len(new_args) == len(expr.args) assert len(new_access_descriptors) == len(expr.access_descriptors) - return Einsum(tuple(new_access_descriptors), - tuple(new_args), - expr.redn_axis_to_redn_descr, - tags=expr.tags, - axes=expr.axes,) + if ( + all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)) + and all( + new_acc_descr is acc_descr + for acc_descr, new_acc_descr in zip( + expr.access_descriptors, + new_access_descriptors, + strict=True))): + return expr + else: + return Einsum(tuple(new_access_descriptors), + tuple(new_args), + axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: @@ -97,6 +169,6 @@ def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: alter its value. 
""" mapper = EinsumWithNoBroadcastsRewriter() - return cast("MappedT", mapper(expr)) + return cast("MappedT", mapper(expr, None)) # vim:fdm=marker From 3b8bd2dd61f454f1ed7785439450020c960dbd10 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 6 Feb 2025 22:01:19 -0600 Subject: [PATCH 163/178] forbid DependencyMapper from being called on functions --- pytato/transform/__init__.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 0330f25ac..df89f4f1d 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1519,7 +1519,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> ResultT: # {{{ DependencyMapper -class DependencyMapper(CombineMapper[R, R]): +class DependencyMapper(CombineMapper[R, Never]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of :class:`pytato.array.Array`'s it depends on. @@ -1581,14 +1581,10 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) - def map_function_definition(self, expr: FunctionDefinition) -> R: + def map_call(self, expr: Call) -> R: # do not include arrays from the function's body as it would involve # putting arrays from different namespaces into the same collection. - return frozenset() - - def map_call(self, expr: Call) -> R: - return self.combine(self.rec_function_definition(expr.function), - *[self.rec(bnd) for bnd in expr.bindings.values()]) + return self.combine(*[self.rec(bnd) for bnd in expr.bindings.values()]) def map_named_call_result(self, expr: NamedCallResult) -> R: return self.rec(expr._container) From 8bc1d767a66e9d753d3982a75860e348fd22789e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 27 Feb 2025 17:59:51 -0600 Subject: [PATCH 164/178] deduplicate in advection example --- examples/advection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/advection.py b/examples/advection.py index 339ff80a8..fd308ae50 100755 --- a/examples/advection.py +++ b/examples/advection.py @@ -156,6 +156,7 @@ def test_advection_convergence(order, flux_type): op = AdvectionOperator(discr, c=1, flux_type=flux_type, dg_ops=dg_ops) result = op.apply(u) + result = pt.transform.Deduplicator()(result) prog = pt.generate_loopy(result, cl_device=queue.device) From 1db66961765c4fdb4b61bed5bc26b6a1e8ab538c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 14 Nov 2024 11:27:02 -0800 Subject: [PATCH 165/178] fix docstring of NodeCountMapper --- pytato/analysis/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 019fbd846..0b5a48fc0 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -437,9 +437,14 @@ class NodeCountMapper(CachedWalkMapper[[]]): Counts the number of nodes of a given type in a DAG. .. autoattribute:: expr_type_counts + + Dictionary mapping node types to number of nodes of that type. + .. autoattribute:: count_duplicates - Dictionary mapping node types to number of nodes of that type. + If `True`, counts each array instance as a separate node, even if some are + equal. 
+ """ def __init__( From c9891acd8683beb71a71f7d4c18699e1f2c97828 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 20 Aug 2024 16:30:30 -0500 Subject: [PATCH 166/178] add NodeT in analysis --- pytato/analysis/__init__.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 0b5a48fc0..a9030d401 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -59,6 +59,9 @@ from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.loopy import LoopyCall + +NodeT = Array | FunctionDefinition + __doc__ = """ .. currentmodule:: pytato.analysis @@ -455,7 +458,7 @@ def __init__( super().__init__(_visited_functions=_visited_functions) from collections import defaultdict - self.expr_type_counts: dict[type[Any], int] = defaultdict(int) + self.expr_type_counts: dict[type[NodeT], int] = defaultdict(int) self.count_duplicates = count_duplicates def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: @@ -473,14 +476,14 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: _visited_functions=self._visited_functions) def post_visit(self, expr: Any) -> None: - if not isinstance(expr, DictOfNamedArrays): + if isinstance(expr, NodeT): self.expr_type_counts[type(expr)] += 1 def get_node_type_counts( outputs: Array | DictOfNamedArrays, count_duplicates: bool = False - ) -> dict[type[Any], int]: + ) -> dict[type[NodeT], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -541,7 +544,7 @@ def __init__(self, _visited_functions: set[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) from collections import defaultdict - self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) + self.expr_multiplicity_counts: dict[NodeT, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: # Returns each node, including nodes that are duplicates @@ -552,12 +555,12 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: return id(expr) def post_visit(self, expr: Any) -> None: - if not isinstance(expr, DictOfNamedArrays): + if isinstance(expr, NodeT): self.expr_multiplicity_counts[expr] += 1 def get_node_multiplicities( - outputs: Array | DictOfNamedArrays) -> dict[Array, int]: + outputs: Array | DictOfNamedArrays) -> dict[NodeT, int]: """ Returns the multiplicity per `expr`. """ From f4caafc54d5136ca618a559885fbb918963c49b5 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 20 Aug 2024 16:31:09 -0500 Subject: [PATCH 167/178] support function traversal in NodeCountMapper and NodeMultiplicityMapper --- pytato/analysis/__init__.py | 43 +++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a9030d401..5ee3a9994 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -60,7 +60,8 @@ from pytato.loopy import LoopyCall -NodeT = Array | FunctionDefinition +# FIXME: Think about whether this makes sense +NodeT = Array | FunctionDefinition | Call __doc__ = """ .. 
currentmodule:: pytato.analysis @@ -479,6 +480,19 @@ def post_visit(self, expr: Any) -> None: if isinstance(expr, NodeT): self.expr_type_counts[type(expr)] += 1 + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee(expr) + for ret in expr.returns.values(): + new_mapper(ret) + + for node_type, count in new_mapper.expr_type_counts.items(): + self.expr_type_counts[node_type] += count + + self.post_visit(expr) + def get_node_type_counts( outputs: Array | DictOfNamedArrays, @@ -540,9 +554,14 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]): .. autoattribute:: expr_multiplicity_counts """ - def __init__(self, _visited_functions: set[Any] | None = None) -> None: + def __init__( + self, + traverse_functions: bool = True, + _visited_functions: set[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) + self.traverse_functions = traverse_functions + from collections import defaultdict self.expr_multiplicity_counts: dict[NodeT, int] = defaultdict(int) @@ -554,10 +573,26 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: # Returns each node, including nodes that are duplicates return id(expr) + def visit(self, expr: Any) -> bool: + return not isinstance(expr, FunctionDefinition) or self.traverse_functions + def post_visit(self, expr: Any) -> None: if isinstance(expr, NodeT): self.expr_multiplicity_counts[expr] += 1 + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee(expr) + for ret in expr.returns.values(): + new_mapper(ret) + + for subexpr, count in new_mapper.expr_multiplicity_counts.items(): + self.expr_multiplicity_counts[subexpr] += count + + self.post_visit(expr) + def get_node_multiplicities( outputs: Array | DictOfNamedArrays) -> dict[NodeT, int]: @@ -602,8 +637,8 @@ def map_function_definition(self, expr: FunctionDefinition) -> None: return new_mapper = self.clone_for_callee(expr) - for subexpr in expr.returns.values(): - new_mapper(subexpr) + for ret in expr.returns.values(): + new_mapper(ret) self.count += new_mapper.count self.post_visit(expr) From fc8d7623a33f3111511762551a91769834fcdb57 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 20 Aug 2024 15:14:41 -0500 Subject: [PATCH 168/178] add get_num_node_instances --- pytato/analysis/__init__.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5ee3a9994..6970f5596 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -539,6 +539,31 @@ def get_num_nodes( return sum(ncm.expr_type_counts.values()) + +def get_num_node_instances( + outputs: Array | DictOfNamedArrays, + node_type: type[NodeT], + strict: bool = True, + count_duplicates: bool = False) -> int: + """ + Returns the number of nodes in DAG *outputs* that have type *node_type* (if + *strict* is `True`) or are instances of *node_type* (if *strict* is `False`). 
+    """
+
+    from pytato.codegen import normalize_outputs
+    outputs = normalize_outputs(outputs)
+
+    ncm = NodeCountMapper(count_duplicates)
+    ncm(outputs)
+
+    if strict:
+        return ncm.expr_type_counts[node_type]
+    else:
+        return sum(
+            count
+            for other_node_type, count in ncm.expr_type_counts.items()
+            if issubclass(other_node_type, node_type))
+
 # }}}


From 55bdf0489cc0de94d7e103023d45be2a6745f910 Mon Sep 17 00:00:00 2001
From: Matthew Smith
Date: Tue, 12 Mar 2024 09:25:59 -0500
Subject: [PATCH 169/178] add collect_nodes_of_type

---
 pytato/analysis/__init__.py | 70 +++++++++++++++++++++++++++++++++++++
 1 file changed, 70 insertions(+)

diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py
index 6970f5596..0e895b461 100644
--- a/pytato/analysis/__init__.py
+++ b/pytato/analysis/__init__.py
@@ -78,6 +78,8 @@
 
 .. autofunction:: get_num_call_sites
 
+.. autofunction:: collect_nodes_of_type
+
 .. autoclass:: DirectPredecessorsGetter
 
 .. autoclass:: TagCountMapper
@@ -773,4 +775,72 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None:
 
 # }}}
 
+
+# {{{ NodeCollector
+
+@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
+class NodeCollector(CachedWalkMapper[[]]):
+    """
+    Collects all nodes of a given type in a DAG.
+
+    .. attribute:: nodes
+
+        The collected nodes.
+    """
+
+    def __init__(
+            self,
+            node_type: type[NodeT],
+            traverse_functions: bool = True) -> None:
+        super().__init__()
+        self.node_type = node_type
+        self.traverse_functions = traverse_functions
+        self.nodes: set[NodeT] = set()
+
+    def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames:
+        return expr
+
+    def get_function_definition_cache_key(
+            self, expr: FunctionDefinition) -> FunctionDefinition:
+        return expr
+
+    def clone_for_callee(
+            self: NodeCollector, function: FunctionDefinition) -> NodeCollector:
+        return type(self)(self.node_type)
+
+    def visit(self, expr: Any) -> bool:
+        return not isinstance(expr, FunctionDefinition) or self.traverse_functions
+
+    def post_visit(self, expr: Any) -> None:
+        if isinstance(expr, self.node_type):
+            self.nodes.add(expr)
+
+    def map_function_definition(self, expr: FunctionDefinition) -> None:
+        if not self.visit(expr):
+            return
+
+        new_mapper = self.clone_for_callee(expr)
+        for ret in expr.returns.values():
+            new_mapper(ret)
+
+        self.nodes |= new_mapper.nodes
+
+        self.post_visit(expr)
+
+
+def collect_nodes_of_type(
+        outputs: Array | DictOfNamedArrays,
+        node_type: type[NodeT]) -> set[NodeT]:
+    """Returns the nodes that are instances of *node_type* in DAG *outputs*."""
+
+    from pytato.codegen import normalize_outputs
+    outputs = normalize_outputs(outputs)
+
+    nc = NodeCollector(node_type)
+    nc(outputs)
+
+    return nc.nodes
+
+# }}}
+
 # vim: fdm=marker

From ebf8562d8ed9d274b88164107ff2c8aa598c3f99 Mon Sep 17 00:00:00 2001
From: Matthew Smith
Date: Tue, 11 Jun 2024 14:27:51 -0500
Subject: [PATCH 170/178] generalize NodeCollector

---
 pytato/analysis/__init__.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py
index 0e895b461..e6331f817 100644
--- a/pytato/analysis/__init__.py
+++ b/pytato/analysis/__init__.py
@@ -52,7 +52,7 @@ if TYPE_CHECKING:
-    from collections.abc import Iterable, Mapping
+    from collections.abc import Callable, Iterable, Mapping
 
     import pytools
 
@@ -778,10 +778,11 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None:
 
 # {{{ NodeCollector
 
+# FIXME: Decide if this should be a CombineMapper instead?
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCollector(CachedWalkMapper[[]]): """ - Collects all nodes of a given type in a DAG. + Collects all nodes matching specified criteria in a DAG. .. attribute:: nodes @@ -790,10 +791,10 @@ class NodeCollector(CachedWalkMapper[[]]): def __init__( self, - node_type: type[NodeT], + collect_func: Callable[[NodeT], bool], traverse_functions: bool = True) -> None: super().__init__() - self.node_type = node_type + self.collect_func = collect_func self.traverse_functions = traverse_functions self.nodes: set[NodeT] = set() @@ -806,13 +807,13 @@ def get_function_definition_cache_key( def clone_for_callee( self: NodeCollector, function: FunctionDefinition) -> NodeCollector: - return type(self)(self.node_type) + return type(self)(self.collect_func) def visit(self, expr: Any) -> bool: return not isinstance(expr, FunctionDefinition) or self.traverse_functions def post_visit(self, expr: Any) -> None: - if isinstance(expr, self.node_type): + if isinstance(expr, NodeT) and self.collect_func(expr): self.nodes.add(expr) def map_function_definition(self, expr: FunctionDefinition) -> None: @@ -830,16 +831,19 @@ def map_function_definition(self, expr: FunctionDefinition) -> None: def collect_nodes_of_type( outputs: Array | DictOfNamedArrays, - node_type: type[NodeT]) -> set[NodeT]: + node_type: type[NodeT]) -> frozenset[NodeT]: """Returns the nodes that are instances of *node_type* in DAG *outputs*.""" - from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - nc = NodeCollector(node_type) + def collect_func(expr: NodeT) -> bool: + return isinstance(expr, node_type) + + nc = NodeCollector(collect_func) nc(outputs) - return nc.nodes + return frozenset(nc.nodes) + # }}} From e264f5b2165cadb09ccbb1bce83ce8adcfe32a89 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Jun 2024 14:28:09 -0500 Subject: [PATCH 171/178] add collect_materialized_nodes --- pytato/analysis/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e6331f817..42e7257f4 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -80,6 +80,8 @@ .. autofunction:: collect_nodes_of_type +.. autofunction:: collect_materialized_nodes + .. autoclass:: DirectPredecessorsGetter .. 
autoclass:: TagCountMapper @@ -845,6 +847,20 @@ def collect_func(expr: NodeT) -> bool: return frozenset(nc.nodes) +def collect_materialized_nodes( + outputs: Array | DictOfNamedArrays) -> frozenset[NodeT]: + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + def collect_func(expr: NodeT) -> bool: + from pytato.tags import ImplStored + return bool(expr.tags_of_type(ImplStored)) + + nc = NodeCollector(collect_func) + nc(outputs) + + return frozenset(nc.nodes) + # }}} # vim: fdm=marker From 4192952b74b7e9abbdb6e3cca9f03abac31d697a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 5 Sep 2024 11:14:28 -0500 Subject: [PATCH 172/178] cosmetic change --- pytato/analysis/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 42e7257f4..36240de0c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -661,6 +661,10 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: return id(expr) + def post_visit(self, expr: Any) -> None: + if isinstance(expr, Call): + self.count += 1 + def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return @@ -672,10 +676,6 @@ def map_function_definition(self, expr: FunctionDefinition) -> None: self.post_visit(expr) - def post_visit(self, expr: Any) -> None: - if isinstance(expr, Call): - self.count += 1 - def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int: """Returns the number of nodes in DAG *outputs*.""" From ee5461a4cc95223fb27bec2b5ed8951ee601b2ef Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 24 Oct 2024 10:58:28 -0500 Subject: [PATCH 173/178] add option to not traverse functions in get_num_nodes --- pytato/analysis/__init__.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 36240de0c..22c698511 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -458,10 +458,13 @@ class NodeCountMapper(CachedWalkMapper[[]]): def __init__( self, count_duplicates: bool = False, + traverse_functions: bool = True, _visited_functions: set[Any] | None = None, ) -> None: super().__init__(_visited_functions=_visited_functions) + self.traverse_functions = traverse_functions + from collections import defaultdict self.expr_type_counts: dict[type[NodeT], int] = defaultdict(int) self.count_duplicates = count_duplicates @@ -478,8 +481,12 @@ def get_function_definition_cache_key( def clone_for_callee(self, function: FunctionDefinition) -> Self: return type(self)( count_duplicates=self.count_duplicates, + traverse_functions=self.traverse_functions, _visited_functions=self._visited_functions) + def visit(self, expr: Any) -> bool: + return not isinstance(expr, FunctionDefinition) or self.traverse_functions + def post_visit(self, expr: Any) -> None: if isinstance(expr, NodeT): self.expr_type_counts[type(expr)] += 1 @@ -520,7 +527,8 @@ def get_node_type_counts( def get_num_nodes( outputs: Array | DictOfNamedArrays, - count_duplicates: bool | None = None + count_duplicates: bool | None = None, + traverse_functions: bool = True ) -> int: """ Returns the number of nodes in DAG *outputs*. 
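
As a usage sketch (illustrative only, not part of any patch in this series; the
placeholder names and shapes below are invented), the analysis helpers added in
PATCH 168-170 and extended by this patch compose as follows:

    import numpy as np

    import pytato as pt
    from pytato.analysis import (
        collect_nodes_of_type,
        get_num_node_instances,
        get_num_nodes,
    )
    from pytato.array import InputArgumentBase

    # Build a small DAG with a shared subexpression.
    x = pt.make_placeholder("x", (4, 4), np.float64)
    y = 2 * x
    dag = pt.make_dict_of_named_arrays({"out1": y + 1, "out2": y + 2})

    # Unique-node count; traverse_functions=False (the option added in this
    # patch) would skip the bodies of FunctionDefinitions.
    num_nodes = get_num_nodes(dag, count_duplicates=False)

    # Non-strict matching counts all subclasses of the given type via
    # issubclass; here, every input argument (Placeholder, DataWrapper,
    # SizeParam).
    num_inputs = get_num_node_instances(
        dag, InputArgumentBase, strict=False, count_duplicates=False)

    # Collect all placeholders reachable from the outputs.
    placeholders = collect_nodes_of_type(dag, pt.Placeholder)
    assert x in placeholders
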
@@ -538,7 +546,9 @@ def get_num_nodes( from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper(count_duplicates) + ncm = NodeCountMapper( + count_duplicates=count_duplicates, + traverse_functions=traverse_functions) ncm(outputs) return sum(ncm.expr_type_counts.values()) From 9e18a5233d1f1f7661f88d7e13adaeee69e46075 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 14 Nov 2024 11:27:48 -0800 Subject: [PATCH 174/178] add trace_dependencies --- pytato/analysis/__init__.py | 189 ++++++++++++++++++++++++++++++++++- pytato/transform/__init__.py | 3 + 2 files changed, 190 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 22c698511..ba4fa6ad7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -36,7 +36,9 @@ from pytato.array import ( Array, + AxisPermutation, Concatenate, + DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -44,11 +46,21 @@ IndexRemappingBase, InputArgumentBase, NamedArray, + Placeholder, + Reshape, + Roll, ShapeType, + SizeParam, Stack, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, Mapper +from pytato.transform import ( + ArrayOrNames, + CachedWalkMapper, + CombineMapper, + IndexOrShapeExpr, + Mapper, +) if TYPE_CHECKING: @@ -57,7 +69,7 @@ import pytools from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder - from pytato.loopy import LoopyCall + from pytato.loopy import LoopyCall, LoopyCallResult # FIXME: Think about whether this makes sense @@ -82,6 +94,8 @@ .. autofunction:: collect_materialized_nodes +.. autofunction:: trace_dependencies + .. autoclass:: DirectPredecessorsGetter .. autoclass:: TagCountMapper @@ -873,4 +887,175 @@ def collect_func(expr: NodeT) -> bool: # }}} + +# {{{ DependencyTracer + +class DependencyTracer(CombineMapper[frozenset[tuple[Array, ...]], Never]): + """ + Maps a DAG and a node to a :class:`frozenset` of `tuple`\\ s of + :class:`pytato.array.Array`\\ s representing dependency traces from + the node to one of the DAG outputs. + + .. note:: + + Does not recurse into function definitions. + """ + def __init__(self, dependee: Array) -> None: + super().__init__() + self.dependee = dependee + + def rec_idx_or_size_tuple( + self, situp: tuple[IndexOrShapeExpr, ...] 
+ ) -> tuple[frozenset[tuple[Array, ...]], ...]: + return tuple(self.rec(s) for s in situp if isinstance(s, Array)) + + def combine( + self, *args: frozenset[tuple[Array, ...]]) -> frozenset[tuple[Array, ...]]: + from functools import reduce + # FIXME: This doesn't match the docs (original version produced way too + # many results) + combined: frozenset[tuple[Array, ...]] = reduce( + lambda a, b: a | b, args, frozenset()) + if combined: + return frozenset({next(iter(combined))}) + else: + return frozenset() + + def map_index_lambda(self, expr: IndexLambda) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_index_lambda(expr))) + + def map_placeholder(self, expr: Placeholder) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_placeholder(expr))) + + def map_data_wrapper(self, expr: DataWrapper) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_data_wrapper(expr))) + + def map_size_param(self, expr: SizeParam) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_size_param(expr))) + + def map_stack(self, expr: Stack) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_stack(expr))) + + def map_roll(self, expr: Roll) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_roll(expr))) + + def map_axis_permutation( + self, expr: AxisPermutation) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_axis_permutation(expr))) + + def _map_index_base(self, expr: IndexBase) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super()._map_index_base(expr))) + + def map_reshape(self, expr: Reshape) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_reshape(expr))) + + def map_concatenate(self, expr: Concatenate) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_concatenate(expr))) + + def map_einsum(self, expr: Einsum) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_einsum(expr))) + + def map_named_array(self, expr: NamedArray) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_named_array(expr))) + + def map_loopy_call(self, expr: LoopyCall) -> frozenset[tuple[Array, ...]]: + raise AssertionError("Control 
shouldn't reach this point.") + + def map_loopy_call_result( + self, expr: LoopyCallResult) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_loopy_call_result(expr))) + + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_distributed_send_ref_holder(expr))) + + def map_distributed_recv( + self, expr: DistributedRecv) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_distributed_recv(expr))) + + def map_call(self, expr: Call) -> frozenset[tuple[Array, ...]]: + return self.combine(*(self.rec(bnd) + for bnd in expr.bindings.values())) + + def map_named_call_result( + self, expr: NamedCallResult) -> frozenset[tuple[Array, ...]]: + if expr == self.dependee: + return frozenset({(expr,)}) + return self.combine(*( + frozenset({(expr, *subtrace)}) + for subtrace in super().map_named_call_result(expr))) + + +def trace_dependencies( + outputs: Array | DictOfNamedArrays, dependee: Array + ) -> frozenset[tuple[Array, ...]]: + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + dt = DependencyTracer(dependee) + return dt(outputs) + +# }}} + + # vim: fdm=marker diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 6598b9bcb..2046a9d5d 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1216,6 +1216,9 @@ def map_placeholder(self, expr: Placeholder) -> ResultT: def map_data_wrapper(self, expr: DataWrapper) -> ResultT: return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + def map_size_param(self, expr: SizeParam) -> ResultT: + return self.combine(cast("ResultT", frozenset({}))) + def map_stack(self, expr: Stack) -> ResultT: return self.combine(*(self.rec(ary) for ary in expr.arrays)) From c0cdbdd88942c42811891cf2f8a7b44738ab8a34 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Mar 2025 17:18:38 -0600 Subject: [PATCH 175/178] add precomputable expression evaluation --- pytato/__init__.py | 2 + pytato/transform/__init__.py | 132 ++++++++++++++++++++++++++++++++++- 2 files changed, 133 insertions(+), 1 deletion(-) diff --git a/pytato/__init__.py b/pytato/__init__.py index 56254b928..f2e15ceb4 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -158,6 +158,7 @@ def set_debug_enabled(flag: bool) -> None: from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.target.loopy.codegen import generate_loopy from pytato.target.python.jax import generate_jax +from pytato.transform import precompute_subexpressions from pytato.transform.calls import inline_calls, tag_all_calls_to_be_inlined from pytato.transform.lower_to_index_lambda import to_index_lambda from pytato.transform.metadata import unify_axes_tags @@ -260,6 +261,7 @@ def set_debug_enabled(flag: bool) -> None: "number_distributed_tags", "ones", "pad", + "precompute_subexpressions", "prod", "real", "reshape", diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index df89f4f1d..542e08026 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -82,7 +82,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Mapping + from 
collections.abc import Callable, Iterable, Mapping, Sequence ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays @@ -1666,6 +1666,136 @@ def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: # }}} +# {{{ precompute_subexpressions + +# FIXME: Think about what happens when subexpressions contain outlined functions +class _PrecomputableSubexpressionGatherer( + CombineMapper[frozenset[Array], frozenset[Array]]): + """ + Mapper to find subexpressions that do not depend on any placeholders. + """ + def rec(self, expr: ArrayOrNames) -> frozenset[Array]: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: frozenset[Array] = Mapper.rec(self, expr) + if not isinstance(expr, + Placeholder + | DictOfNamedArrays + | Call): + assert isinstance(expr, Array) + from pytato.analysis import DirectPredecessorsGetter + if result == DirectPredecessorsGetter()(expr): + result = frozenset({expr}) + return self._cache_add(inputs, result) + + # type-ignore reason: incompatible ret. type with super class + def __call__(self, expr: ArrayOrNames) -> frozenset[Array]: # type: ignore + subexprs = self.rec(expr) + + # Need to treat data arrays as precomputable during recursion, but afterwards + # we only care about larger expressions containing them *or* their shape if + # it's a non-constant expression + # FIXME: Does it even make sense for a data array to have an expression as + # a shape? Maybe this isn't necessary... + + data_subexprs = { + ary + for ary in subexprs + if isinstance(ary, DataWrapper | DistributedRecv)} + + subexprs -= data_subexprs + + for ary in data_subexprs: + subexprs |= self.combine(*self.rec_idx_or_size_tuple(ary.shape)) + + return subexprs + + def combine(self, *args: frozenset[Array]) -> frozenset[Array]: + from functools import reduce + return reduce(lambda a, b: a | b, args, frozenset()) + + def map_function_definition(self, expr: FunctionDefinition) -> frozenset[Array]: + # FIXME: Ignoring subexpressions inside function definitions for now + return frozenset() + + def map_call(self, expr: Call) -> frozenset[Array]: + rec_fn = self.rec_function_definition(expr.function) + assert not rec_fn + rec_bindings: Mapping[str, frozenset[Array]] = immutabledict({ + name: self.rec(bnd) if isinstance(bnd, Array) else frozenset({bnd}) + for name, bnd in expr.bindings.items()}) + if all( + rec_bindings[name] == frozenset({expr.bindings[name]}) + for name in expr.bindings): + # FIXME: This conflicts with type annotations + return frozenset({expr}) + else: + return self.combine(rec_fn, *rec_bindings.values()) + + +class _PrecomputableSubexpressionReplacer(CopyMapper): + """ + Mapper to replace precomputable subexpressions found by + :class:`_PrecomputableSubexpressionGatherer` with the evaluated versions. 
+ """ + def __init__( + self, + replacement_map: Mapping[Array, Array], + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: + TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.replacement_map = replacement_map + + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: ArrayOrNames | None = None + if isinstance(expr, Array): + try: + result = self.replacement_map[expr] + except KeyError: + pass + result = self.rec(result) if result is not None else Mapper.rec(self, expr) + return self._cache_add(inputs, result) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + {}, + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]",self._function_cache)) + + +def precompute_subexpressions( + expr: ArrayOrNames, + # FIXME: Don't use Sequence for this + eval_func: Callable[[Sequence[ArrayOrNames]], Sequence[ArrayOrNames]] + ) -> ArrayOrNames: + """Evaluate subexpressions in *expr* that do not depend on any placeholders.""" + precomputable_subexprs = _PrecomputableSubexpressionGatherer()(expr) + for subexpr in precomputable_subexprs: + from pytato.analysis import get_num_nodes + nnodes = get_num_nodes(subexpr) + if nnodes > 1: + print( + "Found precomputable subexpression of type " + f"{type(subexpr).__name__} with {nnodes} nodes.") + from pytools.obj_array import make_obj_array + # FIXME: Don't use object array + precomputable_subexprs_ary = make_obj_array(list(precomputable_subexprs)) + evaled_subexprs_ary = eval_func(precomputable_subexprs_ary) + subexpr_to_evaled_subexpr = dict( + zip(precomputable_subexprs_ary, evaled_subexprs_ary, strict=True)) + return _PrecomputableSubexpressionReplacer(subexpr_to_evaled_subexpr)(expr) + +# }}} + + # {{{ SizeParamGatherer class SizeParamGatherer( From 3e0de764c4e88ccc9c250f6cfdf7aa13dfe47ed4 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Mar 2025 17:40:38 -0600 Subject: [PATCH 176/178] [NOPRODUCTION] squashed concatenation draft --- pytato/__init__.py | 9 +- pytato/array.py | 19 + pytato/equality.py | 293 ++- pytato/tags.py | 30 + pytato/transform/__init__.py | 57 +- pytato/transform/calls.py | 2086 ++++++++++++++++++++- pytato/transform/lower_to_index_lambda.py | 57 +- pytato/transform/metadata.py | 55 +- pytato/utils.py | 1 + test/test_codegen.py | 119 +- 10 files changed, 2672 insertions(+), 54 deletions(-) diff --git a/pytato/__init__.py b/pytato/__init__.py index f2e15ceb4..c6066ac99 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -159,7 +159,12 @@ def set_debug_enabled(flag: bool) -> None: from pytato.target.loopy.codegen import generate_loopy from pytato.target.python.jax import generate_jax from pytato.transform import precompute_subexpressions -from pytato.transform.calls import inline_calls, tag_all_calls_to_be_inlined +from pytato.transform.calls import ( + concatenate_calls, + inline_calls, + tag_all_calls_to_be_inlined, + zero_unused_call_bindings, +) from pytato.transform.lower_to_index_lambda import to_index_lambda from pytato.transform.metadata import unify_axes_tags from pytato.transform.remove_broadcasts_einsum import rewrite_einsums_with_no_broadcasts @@ -216,6 +221,7 @@ def set_debug_enabled(flag: bool) -> None: "arctan2", "broadcast_to", "concatenate", + "concatenate_calls", "conj", "cos", "cosh", @@ -289,5 +295,6 @@ def 
set_debug_enabled(flag: bool) -> None: "vdot", "verify_distributed_partition", "where", + "zero_unused_call_bindings", "zeros", ) diff --git a/pytato/array.py b/pytato/array.py index 27b2f1f4d..267bdeaa5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -2046,6 +2046,25 @@ def _get_created_at_tag(stacklevel: int = 1) -> frozenset[Tag]: return frozenset({CreatedAt(_PytatoStackSummary(frames))}) +def _inherit_created_at_tag_from(ary: Array, src_ary: Array) -> Array: + from pytato.tags import CreatedAt + try: + tb_tag = next( + tag for tag in src_ary.non_equality_tags + if isinstance(tag, CreatedAt)) + except StopIteration: + tb_tag = None + + if tb_tag is not None: + return attrs.evolve( + ary, + non_equality_tags=frozenset({ + tb_tag if isinstance(tag, CreatedAt) else tag + for tag in ary.non_equality_tags})) + else: + return ary + + def _get_default_tags() -> frozenset[Tag]: return frozenset() diff --git a/pytato/equality.py b/pytato/equality.py index 47bf7a0dc..4b95df358 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -47,7 +47,8 @@ SizeParam, Stack, ) -from pytato.function import FunctionDefinition +from pytato.function import Call, FunctionDefinition, NamedCallResult +from pytato.tags import Tag if TYPE_CHECKING: @@ -59,6 +60,7 @@ __doc__ = """ .. autoclass:: EqualityComparer +.. autoclass:: SimilarityComparer """ @@ -328,4 +330,293 @@ def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: # }}} + +# {{{ SimilarityComparer + +class SimilarityComparer: + """ + A :class:`pytato.array.Array` visitor to check structural similarity between two + expression DAGs. Data and array shapes are allowed to be different. + + .. note:: + + - Compares two expression graphs ``expr1``, ``expr2`` in :math:`O(N)` + comparisons, where :math:`N` is the number of nodes in ``expr1``. + - This visitor was introduced to memoize the sub-expression comparisons + of the expressions to be compared. Not memoizing the sub-expression + comparisons results in :math:`O(2^N)` complexity for the comparison + operation, where :math:`N` is the number of nodes in expressions. See + `GH-Issue-163 ` for + more on this. + """ + def __init__( + self, + # FIXME: tuple? 
+            ignore_tag_types: frozenset[type] | None = None,
+            err_on_not_similar: bool = False) -> None:
+        # Uses the same cache for both arrays and functions
+        self._cache: dict[tuple[int, int], bool] = {}
+        if ignore_tag_types is None:
+            ignore_tag_types = frozenset()
+        self.ignore_tag_types = tuple(ignore_tag_types)
+        self.err_on_not_similar = err_on_not_similar
+
+    def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: Any) -> bool:
+        cache_key = id(expr1), id(expr2)
+        try:
+            return self._cache[cache_key]
+        except KeyError:
+            method: Callable[
+                [Array | AbstractResultWithNamedArrays | FunctionDefinition, Any],
+                bool]
+
+            try:
+                method = (
+                    getattr(self, expr1._mapper_method)
+                    if isinstance(expr1, (Array, AbstractResultWithNamedArrays))
+                    else self.map_function_definition)
+            except AttributeError:
+                if isinstance(expr1, Array):
+                    result = self.handle_unsupported_array(expr1, expr2)
+                else:
+                    result = self.map_foreign(expr1, expr2)
+            else:
+                result = (expr1 is expr2) or method(expr1, expr2)
+
+            if self.err_on_not_similar and not result:
+                raise ValueError(
+                    f"Not similar: {type(expr1).__name__}, "
+                    f"{type(expr2).__name__}")
+
+            self._cache[cache_key] = result
+            return result
+
+    def __call__(self, expr1: ArrayOrNames, expr2: Any
+                 ) -> bool:
+        return self.rec(expr1, expr2)
+
+    def handle_unsupported_array(self, expr1: Array,
+                                 expr2: Any) -> bool:
+        raise NotImplementedError(type(expr1).__name__)
+
+    def map_foreign(self, expr1: Any, expr2: Any) -> bool:
+        raise NotImplementedError(type(expr1).__name__)
+
+    def _map_tags(self, tags1: frozenset[Tag], tags2: frozenset[Tag]) -> bool:
+        filtered_tags1 = frozenset(
+            tag for tag in tags1 if not isinstance(tag, self.ignore_tag_types))
+        filtered_tags2 = frozenset(
+            tag for tag in tags2 if not isinstance(tag, self.ignore_tag_types))
+        return filtered_tags1 == filtered_tags2
+
+    def map_placeholder(self, expr1: Placeholder, expr2: Any) -> bool:
+        return (expr1.__class__ is expr2.__class__
+                and expr1.name == expr2.name
+                and len(expr1.shape) == len(expr2.shape)
+                and expr1.dtype == expr2.dtype
+                and self._map_tags(expr1.tags, expr2.tags)
+                and expr1.axes == expr2.axes
+                )
+
+    def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool:
+        return (expr1.__class__ is expr2.__class__
+                and expr1.name == expr2.name
+                and self._map_tags(expr1.tags, expr2.tags)
+                and expr1.axes == expr2.axes
+                )
+
+    def map_data_wrapper(self, expr1: DataWrapper, expr2: Any) -> bool:
+        return (expr1.__class__ is expr2.__class__
+                and expr1.data.__class__ is expr2.data.__class__
+                and expr1.name == expr2.name
+                and len(expr1.shape) == len(expr2.shape)
+                and all(self.rec(dim1, dim2)
+                        for dim1, dim2 in zip(expr1.shape, expr2.shape)
+                        if isinstance(dim1, Array))
+                and expr1.dtype == expr2.dtype
+                and self._map_tags(expr1.tags, expr2.tags)
+                and expr1.axes == expr2.axes
+                )
+
+    def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool:
+        return (expr1.__class__ is expr2.__class__
+                and expr1.expr == expr2.expr
+                and (frozenset(expr1.bindings.keys())
+                     == frozenset(expr2.bindings.keys()))
+                and all(self.rec(expr1.bindings[name], expr2.bindings[name])
+                        for name in expr1.bindings)
+                and len(expr1.shape) == len(expr2.shape)
+                and all(self.rec(dim1, dim2)
+                        for dim1, dim2 in zip(expr1.shape, expr2.shape)
+                        if isinstance(dim1, Array))
+                and self._map_tags(expr1.tags, expr2.tags)
+                and expr1.axes == expr2.axes
+                and expr1.var_to_reduction_descr == expr2.var_to_reduction_descr
+                )
+
+    def map_stack(self, expr1: Stack, expr2: Any) -> bool:
+        return 
(expr1.__class__ is expr2.__class__ + and expr1.axis == expr2.axis + and len(expr1.arrays) == len(expr2.arrays) + and all(self.rec(ary1, ary2) + for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_concatenate(self, expr1: Concatenate, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.axis == expr2.axis + and len(expr1.arrays) == len(expr2.arrays) + and all(self.rec(ary1, ary2) + for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_roll(self, expr1: Roll, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.axis == expr2.axis + and expr1.shift == expr2.shift + and self.rec(expr1.array, expr2.array) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_axis_permutation(self, expr1: AxisPermutation, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.axis_permutation == expr2.axis_permutation + and self.rec(expr1.array, expr2.array) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def _map_index_base(self, expr1: IndexBase, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1.array, expr2.array) + and len(expr1.indices) == len(expr2.indices) + and all(self.rec(idx1, idx2) + if (isinstance(idx1, Array) + and isinstance(idx2, Array)) + else idx1 == idx2 + for idx1, idx2 in zip(expr1.indices, expr2.indices)) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_basic_index(self, expr1: BasicIndex, expr2: Any) -> bool: + return self._map_index_base(expr1, expr2) + + def map_contiguous_advanced_index(self, + expr1: AdvancedIndexInContiguousAxes, + expr2: Any + ) -> bool: + return self._map_index_base(expr1, expr2) + + def map_non_contiguous_advanced_index(self, + expr1: AdvancedIndexInNoncontiguousAxes, + expr2: Any + ) -> bool: + return self._map_index_base(expr1, expr2) + + def map_reshape(self, expr1: Reshape, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and len(expr1.newshape) == len(expr2.newshape) + and self.rec(expr1.array, expr2.array) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + ) + + def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.access_descriptors == expr2.access_descriptors + and all(self.rec(ary1, ary2) + for ary1, ary2 in zip(expr1.args, + expr2.args)) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + and expr1.redn_axis_to_redn_descr == expr2.redn_axis_to_redn_descr + ) + + def map_named_array(self, expr1: NamedArray, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1._container, expr2._container) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + and expr1.name == expr2.name) + + def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.translation_unit == expr2.translation_unit + and expr1.entrypoint == expr2.entrypoint + and frozenset(expr1.bindings) == frozenset(expr2.bindings) + and all(self.rec(bnd, + expr2.bindings[name]) + if isinstance(bnd, Array) + else bnd == expr2.bindings[name] + for name, bnd in expr1.bindings.items()) + and self._map_tags(expr1.tags, expr2.tags) + ) + + def 
map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1._container, expr2._container) + and self._map_tags(expr1.tags, expr2.tags) + and expr1.axes == expr2.axes + and expr1.name == expr2.name) + + def map_dict_of_named_arrays(self, expr1: DictOfNamedArrays, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and frozenset(expr1._data.keys()) == frozenset(expr2._data.keys()) + and all(self.rec(expr1._data[name], expr2._data[name]) + for name in expr1._data) + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_distributed_send_ref_holder( + self, expr1: DistributedSendRefHolder, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1.send.data, expr2.send.data) + and self.rec(expr1.passthrough_data, expr2.passthrough_data) + and expr1.send.dest_rank == expr2.send.dest_rank + and expr1.send.comm_tag == expr2.send.comm_tag + and expr1.send.tags == expr2.send.tags + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.src_rank == expr2.src_rank + and expr1.comm_tag == expr2.comm_tag + and len(expr1.shape) == len(expr2.shape) + and expr1.dtype == expr2.dtype + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_function_definition(self, expr1: FunctionDefinition, expr2: Any + ) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.parameters == expr2.parameters + and expr1.return_type == expr2.return_type + and (set(expr1.returns.keys()) == set(expr2.returns.keys())) + and all(self.rec(expr1.returns[k], expr2.returns[k]) + for k in expr1.returns) + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_call(self, expr1: Call, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1.function, expr2.function) + and frozenset(expr1.bindings) == frozenset(expr2.bindings) + and all(self.rec(bnd, + expr2.bindings[name]) + for name, bnd in expr1.bindings.items()) + and self._map_tags(expr1.tags, expr2.tags) + ) + + def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.name == expr2.name + and self.rec(expr1._container, expr2._container)) + +# }}} + # vim: fdm=marker diff --git a/pytato/tags.py b/pytato/tags.py index e0a98b7da..e798a7c8f 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -17,6 +17,9 @@ .. autoclass:: FunctionIdentifier .. autoclass:: CallImplementationTag .. autoclass:: InlineCallTag +.. autoclass:: UseInputAxis +.. autoclass:: ConcatenatedCallInputConcatAxisTag +.. autoclass:: ConcatenatedCallOutputSliceAxisTag """ from dataclasses import dataclass @@ -233,3 +236,30 @@ class InlineCallTag(CallImplementationTag): A :class:`CallImplementationTag` that directs the :class:`pytato.target.Target` to inline the call site. """ + + +@dataclass(frozen=True) +class UseInputAxis(UniqueTag): + """ + A placeholder axis tag indicating that an array should derive tags from one of + its inputs. + """ + key: Hashable + axis: int + + +@dataclass(frozen=True) +class ConcatenatedCallInputConcatAxisTag(UniqueTag): + """ + An axis tag indicating that an array is a concatenation of multiple + inputs resulting from the transformations done in + :func:`pytato.concatenate_calls`. 
+ """ + + +@dataclass(frozen=True) +class ConcatenatedCallOutputSliceAxisTag(UniqueTag): + """ + An axis tag indicating that an array is a slice of a concatenated output + resulting from the transformations done in :func:`pytato.concatenate_calls`. + """ diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index fdcf375ea..04e373f22 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -116,6 +116,7 @@ .. autofunction:: map_and_copy .. autofunction:: materialize_with_mpms .. autofunction:: deduplicate_data_wrappers +.. autofunction:: unify_materialization_tags .. automodule:: pytato.transform.lower_to_index_lambda .. automodule:: pytato.transform.remove_broadcasts_einsum .. automodule:: pytato.transform.einsum_distributive_law @@ -2814,33 +2815,6 @@ def rec_get_user_nodes(expr: ArrayOrNames, # }}} -# {{{ BranchMorpher - -class BranchMorpher(CopyMapper): - """ - A mapper that replaces equal segments of graphs with identical objects. - """ - def __init__(self) -> None: - super().__init__() - self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {} - - def cache_key(self, expr: CachedMapperT) -> Any: - return (id(expr), expr) - - # type-ignore reason: incompatible with Mapper.rec - def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] - rec_expr = super().rec(expr) - try: - # type-ignored because 'result_cache' maps to ArrayOrNames - return self.result_cache[rec_expr] # type: ignore[return-value] - except KeyError: - self.result_cache[rec_expr] = rec_expr - # type-ignored because of super-class' relaxed types - return rec_expr # type: ignore[no-any-return] - -# }}} - - # {{{ deduplicate_data_wrappers class DataWrapperDeduplicator(CopyMapper): @@ -2941,9 +2915,36 @@ def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: dedup.data_wrappers_encountered - len(dedup.data_wrapper_cache))) - return BranchMorpher()(array_or_names) + return array_or_names # }}} +# {{{ unify_materialization_tags + +def unify_materialization_tags(array_or_names: ArrayOrNames) -> ArrayOrNames: + """ + For the expression graph given as *array_or_names*, replace all + non-materialized subexpressions with the corresponding materialized version if + one exists elsewhere in the DAG. + """ + from pytato.analysis import collect_materialized_nodes + materialized_exprs = collect_materialized_nodes(array_or_names) + + non_materialized_expr_to_materialized_expr = { + expr.without_tags(ImplStored()): expr + for expr in materialized_exprs} + + def unify(expr): + if expr.tags_of_type(ImplStored): + return expr + try: + return non_materialized_expr_to_materialized_expr[expr] + except KeyError: + return expr + + return map_and_copy(array_or_names, unify) + +# }}} + # vim: foldmethod=marker diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 2cf8c6479..f3e68302b 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -1,11 +1,16 @@ -""" +from __future__ import annotations + + +__doc__ = """ .. currentmodule:: pytato.transform.calls .. autofunction:: inline_calls +.. autofunction:: concatenate_calls .. autofunction:: tag_all_calls_to_be_inlined -""" -from __future__ import annotations +.. autofunction:: zero_unused_call_bindings +.. autoclass:: CallSiteLocation +""" __copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" @@ -29,29 +34,92 @@ THE SOFTWARE. 
""" +import itertools +import logging +import numpy as np +from functools import partialmethod, reduce +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + FrozenSet, + Generator, + List, + Never, + Sequence, + Set, + Tuple, + cast, +) +from typing_extensions import Self -from typing import TYPE_CHECKING, cast +import attrs +from immutabledict import immutabledict -from typing_extensions import Self +import pymbolic.primitives as prim +from pytools import memoize_method, memoize_on_first_arg +import pytato.scalar_expr as scalar_expr +from pytato.analysis import collect_nodes_of_type from pytato.array import ( AbstractResultWithNamedArrays, Array, + AxisPermutation, + BasicIndex, + Concatenate, + DataWrapper, DictOfNamedArrays, + Einsum, + IndexBase, + IndexLambda, + InputArgumentBase, Placeholder, + Reshape, + Roll, + ShapeComponent, + ShapeType, + SizeParam, + Stack, + concatenate, + zeros, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.tags import InlineCallTag + + +if TYPE_CHECKING: + from collections.abc import Mapping +from pytato.tags import ( + ConcatenatedCallInputConcatAxisTag, + ConcatenatedCallOutputSliceAxisTag, + FunctionIdentifier, + ImplStored, + InlineCallTag, + UseInputAxis, +) from pytato.transform import ( ArrayOrNames, + CachedMapper, + CachedWalkMapper, + CombineMapper, CopyMapper, + Deduplicator, + InputGatherer, TransformMapperCache, + TransformMapperWithExtraArgs, _verify_is_array, ) +from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.utils import are_shape_components_equal if TYPE_CHECKING: - from collections.abc import Mapping + from pytato.loopy import LoopyCallResult + +logger = logging.getLogger(__name__) + +ArrayOnStackT = Tuple[Tuple[Call, ...], Array] # {{{ inlining @@ -168,4 +236,2008 @@ def tag_all_calls_to_be_inlined(expr: ArrayOrNames) -> ArrayOrNames: # }}} + +# {{{ _collect_used_call_inputs + +class _UsedCallInputCollector(CachedWalkMapper[[]]): + def __init__( + self, + _fn_input_gatherers: + dict[FunctionDefinition, InputGatherer] | None = None, + _visited_functions: set[Any] | None = None + ) -> None: + if _fn_input_gatherers is None: + _fn_input_gatherers = {} + + self.call_to_used_inputs: dict[Call, set[Placeholder]] = {} + self._fn_input_gatherers = _fn_input_gatherers + + super().__init__(_visited_functions=_visited_functions) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _fn_input_gatherers=self._fn_input_gatherers, + _visited_functions=self._visited_functions) + + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: + return expr + + def get_function_definition_cache_key( + self, expr: FunctionDefinition) -> FunctionDefinition: + return expr + + # type-ignore-reason: CachedWalkMapper's method takes in variadic args, kwargs + def map_named_call_result( + self, expr: NamedCallResult, # type: ignore[override] + ) -> None: + call = expr._container + try: + input_gatherer = self._fn_input_gatherers[call.function] + except KeyError: + input_gatherer = InputGatherer() + self._fn_input_gatherers[call.function] = input_gatherer + + used_inputs = self.call_to_used_inputs.setdefault(call, set()) + used_inputs |= input_gatherer(call.function.returns[expr.name]) + + super().map_named_call_result(expr) + + +def _collect_used_call_inputs( + expr: ArrayOrNames) -> immutabledict[Call, frozenset[Placeholder]]: + """ + Returns a mapping from :class:`~pytato.function.Call` to the set of input + 
:class:`~pt.array.Placeholder`\ s belonging to its function definition that are + actually used by the expression. In other words, it returns the inputs + corresponding to the call bindings that would remain in the DAG if the call was + inlined. + """ + collector = _UsedCallInputCollector() + collector(expr) + + return immutabledict({ + call: frozenset(inputs) + for call, inputs in collector.call_to_used_inputs.items()}) + +# }}} + + +# {{{ zero_unused_call_bindings + +class _UnusedCallBindingZeroer(CopyMapper): + """ + Mapper to replace unused bindings of :class:`~pytato.function.Call` with zeros + of appropriate shape. + """ + def __init__( + self, + call_to_used_inputs: Mapping[Call, frozenset[Placeholder]], + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.call_to_used_inputs = call_to_used_inputs + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + call_to_used_inputs=self.call_to_used_inputs, + _function_cache=self._function_cache) + + def map_call(self, expr: Call) -> Call: + new_function = self.rec_function_definition(expr.function) + new_bindings = {} + for name, bnd in expr.bindings.items(): + if isinstance(bnd, Array): + if ( + expr.function.get_placeholder(name) + in self.call_to_used_inputs[expr]): + new_bnd = self.rec(bnd) + else: + new_bnd = zeros(bnd.shape, bnd.dtype) + else: + new_bnd = bnd + new_bindings[name] = new_bnd + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values()))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) + + +def zero_unused_call_bindings(expr: ArrayOrNames) -> ArrayOrNames: + """ + Replaces :class:`~pytato.function.Call` bindings that are not used by the + expression with arrays of zeros of the appropriate shape. This can be necessary + for certain transformations such as concatenation, where otherwise bindings + may be retained in the DAG when they should be dropped. + """ + call_to_used_inputs = _collect_used_call_inputs(expr) + return _UnusedCallBindingZeroer(call_to_used_inputs)(expr) + +# }}} + + +# {{{ Concatenatability + +@attrs.define(frozen=True) +class Concatenatability: + """ + Describes how a particular array expression can be concatenated. + """ + + +@attrs.define(frozen=True) +class ConcatableAlongAxis(Concatenatability): + """ + Used to describe an array expression that is concatenatable along *axis*. + """ + axis: int + + +@attrs.define(frozen=True) +class ConcatableIfConstant(Concatenatability): + """ + Used to describe an array expression in a function body that can be + concatenated only if the expression is the same across call-sites. + """ + +# }}} + + +# {{{ concatenate_calls + +@attrs.define(frozen=True) +class CallSiteLocation: + r""" + Records a call-site's location in a :mod:`pytato` expression. + + .. attribute:: call + + The instance of :class:`~pytato.function.Call` being called at this + location. + + .. attribute:: stack + + The call sites within which this particular call is called. + For eg. if ``stack = (c1, c2)``, then :attr:`call` is called within + ``c2``\ 's function body which itself is called from ``c1``\ 's + function body. + """ + call: Call + stack: Tuple[Call, ...] 
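+
+# An illustrative note (reviewer's sketch, not part of the patch proper): the
+# stack distinguishes call sites reached through different chains of enclosing
+# calls. If an outer call `c1` invokes a function whose body contains another
+# call `c2`, the collector below records
+#
+#     CallSiteLocation(c1, stack=())
+#     CallSiteLocation(c2, stack=(c1,))
+#
+# so two otherwise identical calls are kept distinct. A typical driver,
+# assuming `result` is an Array containing (possibly nested) Call nodes:
+#
+#     collector = CallSiteDependencyCollector(stack=())
+#     all_call_sites = collector(result)
+#     dep_map = collector.call_site_to_dep_call_sites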
+
+
+class CallSiteDependencyCollector(
+        CombineMapper[FrozenSet[CallSiteLocation], Never]):
+    r"""
+    Collects all the call sites in a :mod:`pytato` expression along with their
+    interdependencies.
+
+    .. attribute:: stack
+
+        The stack of calls at which the calls are being collected. This
+        attribute is used to specify :attr:`CallSiteLocation.stack` in the
+        :class:`CallSiteLocation`\ s being built. Must be altered (by creating
+        a new instance of the mapper) before entering the function body of a
+        new :class:`~pytato.function.Call`.
+
+    .. attribute:: call_site_to_dep_call_sites
+
+        A mapping from call site to the call sites on which it depends, for each
+        call site present in the expression.
+    """
+    def __init__(self, stack: Tuple[Call, ...]) -> None:
+        self.stack = stack
+        self.call_site_to_dep_call_sites: \
+            Dict[CallSiteLocation, FrozenSet[CallSiteLocation]] = {}
+        super().__init__()
+
+    def combine(self, *args: FrozenSet[CallSiteLocation]
+                ) -> FrozenSet[CallSiteLocation]:
+        return reduce(lambda a, b: a | b, args, frozenset())
+
+    def map_size_param(self, expr: SizeParam) -> FrozenSet[CallSiteLocation]:
+        return frozenset()
+
+    def map_call(self, expr: Call) -> FrozenSet[CallSiteLocation]:
+        cs = CallSiteLocation(expr, self.stack)
+
+        new_mapper_for_fn = CallSiteDependencyCollector(stack=self.stack + (expr,))
+        dependent_call_sites = self.combine(
+            *[
+                self.rec(bnd) for bnd in expr.bindings.values()
+                if isinstance(bnd, Array)],
+            *[new_mapper_for_fn(ret)
+              for ret in expr.function.returns.values()])
+
+        self.call_site_to_dep_call_sites[cs] = dependent_call_sites
+        self.call_site_to_dep_call_sites.update(
+            new_mapper_for_fn.call_site_to_dep_call_sites)
+
+        return self.combine(frozenset([cs]), dependent_call_sites)
+
+
+class _NamedCallResultReplacerPostConcatenate(CopyMapper):
+    """
+    Mapper to replace instances of :class:`~pytato.function.NamedCallResult` as
+    per :attr:`replacement_map`.
+
+    .. attribute:: current_stack
+
+        Records the stack to track which function body the mapper is
+        traversing. Must be altered (by creating a new instance) before
+        entering the function body of a new :class:`~pytato.function.Call`.
+    """
+    def __init__(
+            self,
+            replacement_map: Mapping[
+                Tuple[
+                    NamedCallResult,
+                    Tuple[Call, ...]],
+                Array],
+            current_stack: Tuple[Call, ...],
+            _cache: TransformMapperCache[ArrayOrNames, []] | None = None,
+            _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None
+    ) -> None:
+        super().__init__(_cache=_cache, _function_cache=_function_cache)
+        self.replacement_map = replacement_map
+        self.current_stack = current_stack
+
+    def clone_for_callee(
+            self, function: FunctionDefinition) -> Self:
+        raise AssertionError("Control should not reach here."
+                             " Call clone_with_new_call_on_stack instead.")
+
+    def clone_with_new_call_on_stack(self, expr: Call) -> Self:
+        # type-ignore-reason: Mapper class does not define these attributes. 
+ return type(self)( # type: ignore[call-arg] + self.replacement_map, # type: ignore[attr-defined] + self.current_stack + (expr,), # type: ignore[attr-defined] + _function_cache=self._function_cache + ) + + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + # No clone here because we're cloning in map_call instead + new_returns = {name: self.rec(ret) + for name, ret in expr.returns.items()} + if all( + new_ret is ret + for ret, new_ret in zip( + expr.returns.values(), + new_returns.values())): + return expr + else: + return attrs.evolve(expr, returns=immutabledict(new_returns)) + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + new_mapper = self.clone_with_new_call_on_stack(expr) + new_function = new_mapper.rec_function_definition(expr.function) + new_bindings = { + name: self.rec(bnd) if isinstance(bnd, Array) else bnd + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values()))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + try: + new_expr = self.replacement_map[expr, self.current_stack] + if isinstance(new_expr, NamedCallResult): + return super().map_named_call_result(new_expr) + else: + return self.rec(new_expr) + except KeyError: + return super().map_named_call_result(expr) + + +def _have_same_axis_length(arrays: Collection[Array], + iaxis: int) -> bool: + """ + Returns *True* only if every array in *arrays* have the same axis length + along *iaxis*. + """ + axis_length = next(iter(arrays)).shape[iaxis] + return all(are_shape_components_equal(other_ary.shape[iaxis], + axis_length) + for other_ary in arrays) + + +def _have_same_axis_length_except(arrays: Collection[Array], + iaxis: int) -> bool: + """ + Returns *True* only if every array in *arrays* have the same + dimensionality and have axes with the same lengths except along the + *iaxis*-axis. + """ + ndim = next(iter(arrays)).ndim + return (all(ary.ndim == ndim for ary in arrays) + and all(_have_same_axis_length(arrays, idim) + for idim in range(ndim) + if idim != iaxis)) + + +@attrs.define(frozen=True) +class _InputConcatabilityGetterAcc: + r""" + Return type for :class:`_InputConcatabilityGetter`. An instance of this class is + returned after mapping a :class:`~pytato.Array` expression. + + .. attribute:: seen_inputs + + A :class:`frozenset` of all :class:`pytato.InputArgumentBase` + predecessors of a node. + + .. attribute:: input_concatability + + Records the constraints that come along with concatenating the array + being mapped. The constraints are recorded as a mapping from the axes + of the array being mapped to the axes of the input arguments. This + mapping informs us which axes in the :class:`InputArgumentBase`\ s' + must be concatenated to soundly concatenate a particular axis in the + array being mapped. The axes in this mapping are represented using + :class:`int`. If certain axes are missing in this mapping, then + concatenation cannot be performed along those axes for the mapped + array. 
+ """ + seen_inputs: FrozenSet[InputArgumentBase] + input_concatability: Mapping[Concatenatability, + Mapping[InputArgumentBase, Concatenatability]] + + def __post_init__(self) -> None: + assert all( + frozenset(input_concat.keys()) == self.seen_inputs + for input_concat in self.input_concatability.values()) + + __attrs_post_init__ = __post_init__ + + +class NonConcatableExpression(RuntimeError): + """ + Used internally by :class:`_ScalarExprConcatabilityMapper`. + """ + + +class _InvalidConcatenatability(RuntimeError): + """ + Used internally by :func:`_get_ary_to_concatenatabilities`. + """ + + +class _ScalarExprConcatabilityMapper(scalar_expr.CombineMapper): + """ + Maps :attr:`~pytato.array.IndexLambda.expr` to the axes of the bindings + that must be concatenated to concatenate the IndexLambda's + :attr:`iaxis`-axis. + + .. attribute:: allow_indirect_addr + + If *True* indirect access are allowed. However, concatenating along the + iaxis-axis would be sound only if the binding which is being indexed + into is same for all the expressions to be concatenated. + """ + def __init__(self, iaxis: int, allow_indirect_addr: bool) -> None: + self.iaxis = iaxis + self.allow_indirect_addr = allow_indirect_addr + super().__init__() + + def combine(self, values: Collection[Mapping[str, Concatenatability]] + ) -> Mapping[str, Concatenatability]: + result: Dict[str, Concatenatability] = {} + for value in values: + for bnd_name, iaxis in value.items(): + try: + if result[bnd_name] != iaxis: + # only one axis of a particular binding can be + # concatenated. If multiple axes must be concatenated + # that means the index lambda is not + # iaxis-concatenatable. + raise NonConcatableExpression + except KeyError: + result[bnd_name] = iaxis + + return immutabledict(result) + + def map_variable(self, expr: prim.Variable) -> Mapping[str, Concatenatability]: + if expr.name == f"_{self.iaxis}": + raise NonConcatableExpression + else: + return immutabledict() + + def map_constant(self, expr: Any) -> Mapping[str, Concatenatability]: + return immutabledict() + + map_nan = map_constant + + def map_subscript(self, expr: prim.Subscript + ) -> Mapping[str, Concatenatability]: + name: str = expr.aggregate.name + rec_indices: List[Mapping[str, Concatenatability]] = [] + for iaxis, idx in enumerate(expr.index_tuple): + if idx == prim.Variable(f"_{self.iaxis}"): + rec_indices.append({name: ConcatableAlongAxis(iaxis)}) + else: + rec_idx = self.rec(idx) + if rec_idx: + if not self.allow_indirect_addr: + raise NonConcatableExpression + else: + # indirect accesses cannot be concatenated in the general + # case unless the indexee is the same for the + # expression graphs being concatenated. + pass + rec_indices.append(rec_idx) + + combined_rec_indices = dict(self.combine(rec_indices)) + + if name not in combined_rec_indices: + combined_rec_indices[name] = ConcatableIfConstant() + + return immutabledict(combined_rec_indices) + + +@memoize_on_first_arg +def _get_binding_to_concatenatability_scalar_expr( + expr: scalar_expr.ScalarExpression, + iaxis: int, + allow_indirect_addr: bool) -> Mapping[str, Concatenatability]: + mapper = _ScalarExprConcatabilityMapper(iaxis, allow_indirect_addr) + return mapper(expr) # type: ignore[no-any-return] + + + +def _get_binding_to_concatenatability(expr: scalar_expr.ScalarExpression, + iaxis: int, + allow_indirect_addr: bool, + ) -> Mapping[str, Concatenatability]: + """ + Maps *expr* using :class:`_ScalarExprConcatabilityMapper`. 
+    """
+    if np.isscalar(expr):
+        # In some cases expr may just be a number, which can't be memoized on
+        return {}
+
+    return _get_binding_to_concatenatability_scalar_expr(
+        expr, iaxis, allow_indirect_addr)
+
+
+def _combine_input_accs(
+        operand_accs: Tuple[_InputConcatabilityGetterAcc, ...],
+        concat_to_operand_concatabilities: Mapping[Concatenatability,
+                                                   Tuple[Concatenatability, ...]
+                                                   ],
+) -> _InputConcatabilityGetterAcc:
+    """
+    For an index lambda ``I`` with operands ``I1, I2, .. IN`` that specify their
+    concatenatability constraints using *operand_accs*, this routine returns
+    the axes concatenation constraints of ``I``.
+
+    :arg concat_to_operand_concatabilities: Mapping of the form ``concat_I ->
+        (C_I1, C_I2, ..., C_IN)`` specifying the concatenatabilities of the
+        operands ``I1, I2, .., IN`` in order to concatenate the
+        ``I`` axis via the criterion ``concat_I``.
+    """
+
+    input_concatabilities: Dict[Concatenatability, Mapping[InputArgumentBase,
+                                                           Concatenatability]] = {}
+    seen_inputs: FrozenSet[InputArgumentBase] = reduce(
+        frozenset.union,
+        (operand_acc.seen_inputs for operand_acc in operand_accs),
+        frozenset())
+
+    # The core logic here is to filter the entries of
+    # concat_to_operand_concatabilities so that all the operands agree on how
+    # the input arguments must be concatenated.
+
+    for out_concat, operand_concatabilities in (concat_to_operand_concatabilities
+                                                .items()):
+        is_i_out_axis_concatenatable = True
+        input_concatability: Dict[InputArgumentBase, Concatenatability] = {}
+
+        for operand_concatability, operand_acc in zip(operand_concatabilities,
+                                                      operand_accs,
+                                                      strict=True):
+            if operand_concatability not in (
+                    operand_acc.input_concatability):
+                # required operand concatability cannot be achieved
+                # => out_concat cannot be concatenated
+                is_i_out_axis_concatenatable = False
+                break
+
+            for input_arg, input_concat in (
+                    operand_acc
+                    .input_concatability[operand_concatability]
+                    .items()):
+                try:
+                    if input_concatability[input_arg] != input_concat:
+                        is_i_out_axis_concatenatable = False
+                        break
+                except KeyError:
+                    input_concatability[input_arg] = input_concat
+            if not is_i_out_axis_concatenatable:
+                break
+
+        if is_i_out_axis_concatenatable:
+            input_concatabilities[out_concat] = immutabledict(input_concatability)
+
+    return _InputConcatabilityGetterAcc(seen_inputs,
+                                        immutabledict(input_concatabilities))
+
+
+@attrs.define(frozen=True)
+class FunctionConcatenability:
+    r"""
+    Records a valid concatenatability criterion for a
+    :class:`pytato.function.FunctionDefinition`.
+
+    .. attribute:: output_to_concatenatability
+
+        A mapping from the name of a
+        :class:`FunctionDefinition`\ 's returned array to how it should be
+        concatenated.
+
+    .. attribute:: input_to_concatenatability
+
+        A mapping from a :class:`FunctionDefinition`\ 's parameter to how it
+        should be concatenated.
+
+    .. note::
+
+        A :class:`FunctionDefinition` typically has multiple valid
+        concatenatability constraints. This class only records one of those
+        valid constraints.
+    """
+    output_to_concatenatability: Mapping[str, Concatenatability]
+    input_to_concatenatability: Mapping[str, Concatenatability]
+
+    def __str__(self) -> str:
+        outputs = []
+        for name, concat in self.output_to_concatenatability.items():
+            outputs.append(f"{name} => {concat}")
+
+        inputs = []
+        for name, concat in self.input_to_concatenatability.items():
+            inputs.append(f"{name} => {concat}")
+
+        output_str = "\n".join(outputs)
+        input_str = "\n".join(inputs)
+
+        return (f"Outputs:\n--------\n{output_str}\n"
+                f"========\nInputs:\n-------\n{input_str}\n"
+                "========")
+
+
+def _combine_named_result_accs_simple(
+        named_result_accs: Mapping[str, _InputConcatabilityGetterAcc]
+) -> Tuple[FunctionConcatenability, ...]:
+    """
+    Combines the concatenatability constraints of named results of a
+    :class:`FunctionDefinition` and returns a :class:`tuple` of the valid
+    *simple* concatenatable constraints (i.e., concatenation of all
+    inputs/outputs along the same axis).
+    """
+    valid_concatenatabilities: List[FunctionConcatenability] = []
+
+    input_args = reduce(
+        frozenset.union,
+        [
+            acc.seen_inputs
+            for acc in named_result_accs.values()],
+        frozenset())
+
+    candidate_concat_axes = reduce(
+        frozenset.union,
+        [
+            frozenset(acc.input_concatability.keys())
+            for acc in named_result_accs.values()],
+        frozenset())
+
+    for i_concat_axis in candidate_concat_axes:
+        if (
+                all(
+                    i_concat_axis in acc.input_concatability
+                    for acc in named_result_accs.values())
+                and all(
+                    (
+                        i_input_axis == i_concat_axis
+                        or isinstance(i_input_axis, ConcatableIfConstant))
+                    for acc in named_result_accs.values()
+                    for i_input_axis in (
+                        acc.input_concatability[i_concat_axis].values()))):
+            output_concats = {name: i_concat_axis for name in named_result_accs}
+            input_concats = {pl.name: i_concat_axis
+                             for pl in input_args
+                             if isinstance(pl, Placeholder)}
+            valid_concatenatabilities.append(
+                FunctionConcatenability(immutabledict(output_concats),
+                                        immutabledict(input_concats)))
+
+    return tuple(valid_concatenatabilities)
+
+
+# FIXME: Find a more efficient way to do this. The number of candidates
+# explodes when the function being concatenated has more than a few outputs
+def _combine_named_result_accs_exhaustive(
+        named_result_accs: Mapping[str, _InputConcatabilityGetterAcc]
+) -> Generator[
+        FunctionConcatenability,
+        None,
+        None]:
+    """
+    Combines the concatenatability constraints of named results of a
+    :class:`FunctionDefinition` and yields the valid
+    concatenatable constraints.
+ """ + potential_concatenatable_output_axes = itertools.product(*[ + [(name, concat) for concat in acc.input_concatability] + for name, acc in named_result_accs.items()]) + + for output_concats in potential_concatenatable_output_axes: + is_concatenatable = True + input_concatability: Dict[InputArgumentBase, Concatenatability] = {} + + for result_name, iresult_axis in output_concats: + for input_arg, i_input_axis in ( + named_result_accs[result_name] + .input_concatability[iresult_axis] + .items()): + try: + if input_concatability[input_arg] != i_input_axis: + is_concatenatable = False + break + except KeyError: + input_concatability[input_arg] = i_input_axis + + if not is_concatenatable: + break + + if is_concatenatable: + pl_concatabilities = {pl.name: concat + for pl, concat in input_concatability.items() + if isinstance(pl, Placeholder)} + yield FunctionConcatenability(immutabledict(output_concats), + immutabledict(pl_concatabilities)) + + +class _InputConcatabilityGetter( + CachedMapper[ArrayOrNames, Never, [ArrayOrNames, ...]]): + """ + Maps :class:`pytato.array.Array` expressions to + :class:`_InputConcatenatabilityGetterAcc` that summarizes constraints + induced on the concatenatability of the inputs of the expression by the + expression's concatenatability. + """ + def get_cache_key( + self, expr: ArrayOrNames, *exprs_from_other_calls: ArrayOrNames + ) -> tuple[ArrayOrNames, ...]: + return (expr, *exprs_from_other_calls) + + def _map_input_arg_base( + self, + expr: InputArgumentBase, + *exprs_from_other_calls: InputArgumentBase, + ) -> _InputConcatabilityGetterAcc: + input_concatenatability: Dict[Concatenatability, + Mapping[InputArgumentBase, + Concatenatability]] = {} + for idim in range(expr.ndim): + input_concatenatability[ConcatableAlongAxis(idim)] = immutabledict( + {expr: ConcatableAlongAxis(idim)}) + + input_concatenatability[ConcatableIfConstant()] = immutabledict( + {expr: ConcatableIfConstant()}) + + return _InputConcatabilityGetterAcc(frozenset([expr]), + immutabledict(input_concatenatability)) + + map_placeholder = _map_input_arg_base + map_data_wrapper = _map_input_arg_base + + def _map_index_lambda_like( + self, + expr: Array, + *exprs_from_other_calls: Array, + allow_indirect_addr: bool) -> _InputConcatabilityGetterAcc: + expr = to_index_lambda(expr) + exprs_from_other_calls = tuple( + to_index_lambda(ary) for ary in exprs_from_other_calls) + + input_accs = tuple( + self.rec( + expr.bindings[name], + *[ + ary.bindings[name] + for ary in exprs_from_other_calls]) + for name in sorted(expr.bindings.keys())) + expr_concat_to_input_concats: Dict[Concatenatability, + Tuple[Concatenatability, ...]] = {} + + for iaxis in range(expr.ndim): + for ary in (expr,) + exprs_from_other_calls: + # If the array has length 1 along this axis, the index may have been + # dropped from the scalar expression, in which case + # _get_binding_to_concatenatability will fail to determine the + # concatenatability. 
If that happens, we have to look at the other + # expressions in the hope that one of them has a non-1 length + if ary.shape[iaxis] == 1: + continue + try: + bnd_name_to_concat = _get_binding_to_concatenatability( + ary.expr, iaxis, allow_indirect_addr) + expr_concat_to_input_concats[ConcatableAlongAxis(iaxis)] = ( + tuple(concat + for _, concat in sorted(bnd_name_to_concat.items(), + key=lambda x: x[0])) + ) + except NonConcatableExpression: + # print(f"{iaxis=}") + # print(f"{ary.expr=}") + # print(f"{ary.shape=}") + break + + expr_concat_to_input_concats[ConcatableIfConstant()] = tuple( + ConcatableIfConstant() for _ in expr.bindings) + + return _combine_input_accs(input_accs, expr_concat_to_input_concats) + + map_index_lambda = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_einsum = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_basic_index = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_roll = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_stack = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_concatenate = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_axis_permutation = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_reshape = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + + map_contiguous_advanced_index = partialmethod(_map_index_lambda_like, + allow_indirect_addr=True) + map_non_contiguous_advanced_index = partialmethod(_map_index_lambda_like, + allow_indirect_addr=True) + + def map_named_call_result( + self, + expr: NamedCallResult, + *exprs_from_other_calls: NamedCallResult, + ) -> _InputConcatabilityGetterAcc: + raise NotImplementedError("nested functions aren't supported.") + + # FIXME: Update the code below to work after changing + # _InputConcatabilityGetter to look at all function calls instead of just + # the template call + assert isinstance(expr._container, Call) + valid_concatenatabilities = _get_valid_concatenatability_constraints_simple( + expr._container.function) + + expr_concat_possibilities = { + valid_concatenability.output_to_concatenatability[expr.name] + for valid_concatenability in valid_concatenatabilities + } + + input_concatenatabilities: Dict[Concatenatability, + Mapping[InputArgumentBase, + Concatenatability]] = {} + rec_bindings = {bnd_name: self.rec(binding) + for bnd_name, binding in expr._container.bindings.items()} + callee_acc = self.rec(expr._container.function.returns[expr.name]) + seen_inputs: Set[InputArgumentBase] = set() + + for seen_input in callee_acc.seen_inputs: + if isinstance(seen_input, Placeholder): + seen_inputs.update(rec_bindings[seen_input.name].seen_inputs) + elif isinstance(seen_input, (DataWrapper, SizeParam)): + seen_inputs.add(seen_input) + else: + raise NotImplementedError(type(seen_input)) + + for concat_possibility in expr_concat_possibilities: + caller_input_concatabilities: Dict[InputArgumentBase, + Concatenatability] = {} + + is_concat_possibility_valid = True + for callee_input_arg, callee_input_concat in ( + callee_acc.input_concatability[concat_possibility].items()): + caller_acc = rec_bindings[callee_input_arg.name] + if isinstance(callee_input_arg, Placeholder): + if callee_input_concat in caller_acc.input_concatability: + for caller_input_arg, caller_input_concat in ( + caller_acc + .input_concatability[callee_input_concat] + .items()): + try: + if (caller_input_concatabilities[caller_input_arg] + != 
caller_input_concat):
+ is_concat_possibility_valid = False
+ break
+ except KeyError:
+ caller_input_concatabilities[caller_input_arg] = (
+ caller_input_concat)
+ if not is_concat_possibility_valid:
+ break
+ else:
+ is_concat_possibility_valid = False
+ break
+ elif isinstance(callee_input_arg, (DataWrapper, SizeParam)):
+ try:
+ if (caller_input_concatabilities[callee_input_arg]
+ != callee_input_concat):
+ is_concat_possibility_valid = False
+ break
+ except KeyError:
+ caller_input_concatabilities[callee_input_arg] = (
+ callee_input_concat)
+ else:
+ raise NotImplementedError(type(callee_input_arg))
+
+ if is_concat_possibility_valid:
+ input_concatenatabilities[concat_possibility] = immutabledict(
+ caller_input_concatabilities)
+
+ return _InputConcatabilityGetterAcc(frozenset(seen_inputs),
+ immutabledict(input_concatenatabilities))
+
+ def map_loopy_call_result(
+ self,
+ expr: "LoopyCallResult",
+ *exprs_from_other_calls: "LoopyCallResult",
+ ) -> _InputConcatabilityGetterAcc:
+ raise ValueError("Loopy Calls are illegal to concatenate. Maybe"
+ " rewrite the operation as array operations?")
+
+
+ def _verify_arrays_can_be_concated_along_axis(
+ arrays: Collection[Array],
+ fields_that_must_be_same: Collection[str],
+ iaxis: int) -> None:
+ """
+ Performs some common checks to determine whether *arrays* from different
+ function bodies can be concatenated.
+
+ :arg arrays: Corresponding expressions from the function bodies of the
+ call sites that are being checked for concatenation along *iaxis*.
+ """
+ if not _have_same_axis_length_except(arrays, iaxis):
+ raise _InvalidConcatenatability("Axis lengths are incompatible.")
+ for field in fields_that_must_be_same:
+ if len({getattr(ary, field) for ary in arrays}) != 1:
+ raise _InvalidConcatenatability(f"Field '{field}' varies across calls.")
+
+
+ def _verify_arrays_same(arrays: Collection[Array]) -> None:
+ if len(set(arrays)) != 1:
+ raise _InvalidConcatenatability("Arrays are not the same.")
+
+
+ def _get_concatenated_shape(arrays: Collection[Array], iaxis: int) -> ShapeType:
+ # type-ignore-reason: mypy expects 'ary.shape[iaxis]' as 'int' since the
+ # 'start' is an 'int'
+ concatenated_axis_length = sum(ary.shape[iaxis] # type: ignore[misc]
+ for ary in arrays)
+ template_ary = next(iter(arrays))
+
+ return tuple(dim
+ if idim != iaxis
+ else concatenated_axis_length
+ for idim, dim in enumerate(template_ary.shape)
+ )
+
+
+ class _ConcatabilityCollector(CachedWalkMapper):
+ def __init__(
+ self,
+ current_stack: Tuple[Call, ...],
+ _visited_functions: set[Any] | None = None
+ ) -> None:
+ self.ary_to_concatenatability: Dict[ArrayOnStackT, Concatenatability] = {}
+ self.current_stack = current_stack
+ self.call_sites_on_hold: Set[Call] = set()
+ super().__init__(_visited_functions=_visited_functions)
+
+ # type-ignore-reason: CachedWalkMapper takes variadic `*args, **kwargs`.
+ def get_cache_key(self, # type: ignore[override]
+ expr: ArrayOrNames,
+ *args: Any,
+ ) -> Tuple[ArrayOrNames, Any]:
+ return (expr, args)
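+
+ # Note: the walk arguments (the assigned concatenatability and the
+ # companion expressions from the other call sites) are part of the
+ # cache key, so a node is re-walked when it is visited with different
+ # arguments.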
+
+ # type-ignore-reason: CachedWalkMapper takes variadic `*args, **kwargs`.
+ def get_function_definition_cache_key(
+ self, # type: ignore[override]
+ expr: FunctionDefinition,
+ *args: Any,
+ ) -> tuple[FunctionDefinition, Any]:
+ return (expr, args)
+
+ def _record_concatability(self, expr: Array,
+ concatenatability: Concatenatability,
+ ) -> None:
+ key = (self.current_stack, expr)
+ assert key not in self.ary_to_concatenatability
+ self.ary_to_concatenatability[key] = concatenatability
+
+ def clone_for_callee(self, function: FunctionDefinition) -> Self:
+ raise AssertionError("Control should not reach here."
+ " Call clone_with_new_call_on_stack instead.")
+
+ def clone_with_new_call_on_stack(self, expr: Call) -> Self:
+ # type-ignore-reason: Mapper class does not define these attributes.
+ return type(self)( # type: ignore[call-arg]
+ self.current_stack + (expr,), # type: ignore[attr-defined]
+ _visited_functions=self._visited_functions
+ )
+
+ def _map_input_arg_base(self,
+ expr: InputArgumentBase,
+ concatenatability: Concatenatability,
+ exprs_from_other_calls: Tuple[Array, ...],
+ ) -> None:
+ if isinstance(concatenatability, ConcatableIfConstant):
+ _verify_arrays_same((expr,) + exprs_from_other_calls)
+ elif isinstance(concatenatability, ConcatableAlongAxis):
+ # FIXME: Probably needs some extra handling for broadcastable arrays
+ _verify_arrays_can_be_concated_along_axis(
+ (expr,) + exprs_from_other_calls,
+ ["dtype", "name"],
+ concatenatability.axis)
+ else:
+ raise NotImplementedError(type(concatenatability))
+
+ self._record_concatability(expr, concatenatability)
+
+ map_placeholder = _map_input_arg_base # type: ignore[assignment]
+ map_data_wrapper = _map_input_arg_base # type: ignore[assignment]
+
+ def _map_index_lambda_like(self,
+ expr: Array,
+ concatenatability: Concatenatability,
+ exprs_from_other_calls: Tuple[Array, ...],
+ allow_indirect_addr: bool,
+ ) -> None:
+ self._record_concatability(expr, concatenatability)
+
+ idx_lambda = to_index_lambda(expr)
+ idx_lambdas_from_other_calls = tuple(to_index_lambda(ary)
+ for ary in exprs_from_other_calls)
+
+ if isinstance(concatenatability, ConcatableIfConstant):
+ _verify_arrays_same((idx_lambda,) + idx_lambdas_from_other_calls)
+ for bnd_name in idx_lambda.bindings:
+ self.rec(
+ idx_lambda.bindings[bnd_name], concatenatability,
+ tuple(
+ ary.bindings[bnd_name]
+ for ary in idx_lambdas_from_other_calls))
+ elif isinstance(concatenatability, ConcatableAlongAxis):
+ _verify_arrays_can_be_concated_along_axis(
+ (idx_lambda, ) + idx_lambdas_from_other_calls,
+ ["dtype"],
+ concatenatability.axis)
+ if len({
+ ary.expr
+ for ary in (idx_lambda,) + idx_lambdas_from_other_calls
+ if ary.shape[concatenatability.axis] != 1}) != 1:
+ raise _InvalidConcatenatability(
+ "Cannot concatenate the calls; the scalar expressions"
+ " differ across call sites.")
+ bnd_name_to_concat = None
+ for ary in (idx_lambda,) + idx_lambdas_from_other_calls:
+ if ary.shape[concatenatability.axis] > 1:
+ bnd_name_to_concat = _get_binding_to_concatenatability(
+ ary.expr, concatenatability.axis, allow_indirect_addr)
+ break
+ if bnd_name_to_concat is None:
+ bnd_name_to_concat = _get_binding_to_concatenatability(
+ idx_lambda.expr, concatenatability.axis, allow_indirect_addr)
+ for bnd_name, bnd_concat in bnd_name_to_concat.items():
+ self.rec(idx_lambda.bindings[bnd_name], bnd_concat,
+ tuple(ary.bindings[bnd_name]
+ for ary in idx_lambdas_from_other_calls))
+ else:
+ raise NotImplementedError(type(concatenatability))
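+
+ # Every node type below is lowered to an index lambda for this walk and
+ # hence shares _map_index_lambda_like; only the advanced-indexing node
+ # types are permitted to use indirect addressing.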
+ map_index_lambda = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=False)
+ map_einsum = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=False)
+ map_basic_index = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=False)
+ map_roll = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=False)
+ map_stack = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=False)
+ map_concatenate = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=False)
+ map_axis_permutation = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=False)
+ map_reshape = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=False)
+
+ map_contiguous_advanced_index = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=True)
+ map_non_contiguous_advanced_index = partialmethod( # type: ignore[assignment]
+ _map_index_lambda_like, allow_indirect_addr=True)
+
+ # type-ignore-reason: CachedWalkMapper.map_call takes in variadic args, kwargs
+ def map_call(self, # type: ignore[override]
+ expr: Call,
+ exprs_from_other_calls: Tuple[Call, ...]) -> None:
+ if not all(
+ (self.current_stack, named_result) in self.ary_to_concatenatability
+ for named_result in expr.values()):
+ self.call_sites_on_hold.add(expr)
+ else:
+ self.call_sites_on_hold -= {expr}
+ # FIXME The code below bypasses caching of function definitions
+ new_mapper = self.clone_with_new_call_on_stack(expr)
+ for name, val_in_callee in expr.function.returns.items():
+ new_mapper(val_in_callee,
+ self.ary_to_concatenatability[(self.current_stack,
+ expr[name])],
+ tuple(other_call.function.returns[name]
+ for other_call in exprs_from_other_calls)
+ )
+
+ if new_mapper.call_sites_on_hold:
+ raise NotImplementedError("Call sites that do not use all"
+ " of the returned values are not yet"
+ " supported for concatenation.")
+
+ for ary, concat in new_mapper.ary_to_concatenatability.items():
+ assert ary not in self.ary_to_concatenatability
+ self.ary_to_concatenatability[ary] = concat
+
+ for name, binding in expr.bindings.items():
+ if not isinstance(binding, Array):
+ continue
+ concat = (
+ new_mapper
+ .ary_to_concatenatability[(self.current_stack + (expr,),
+ expr.function.get_placeholder(name))]
+ )
+ self.rec(binding,
+ concat,
+ tuple(other_call.bindings[name]
+ for other_call in exprs_from_other_calls))
+
+ # type-ignore-reason: CachedWalkMapper's method takes in variadic args, kwargs
+ def map_named_call_result(self, expr: NamedCallResult, # type: ignore[override]
+ concatenatability: Concatenatability,
+ exprs_from_other_calls: Tuple[Array, ...],
+ ) -> None:
+ self._record_concatability(expr, concatenatability)
+ if any(not isinstance(ary, NamedCallResult)
+ for ary in exprs_from_other_calls):
+ raise _InvalidConcatenatability()
+
+ # type-ignore-reason: mypy does not respect the conditional which
+ # asserts that all arrays in `exprs_from_other_calls` are
+ # NamedCallResult.
+ self.rec(expr._container,
+ tuple(ary._container # type: ignore[attr-defined]
+ for ary in exprs_from_other_calls)
+ )
+
+ def map_loopy_call_result(self, expr: "LoopyCallResult"
+ ) -> None:
+ raise ValueError("Loopy Calls are illegal to concatenate. 
Maybe" + " rewrite the operation as array operations?") + + +# Memoize the creation of concatenated input arrays to avoid copies +class _InputConcatenator: + def __init__(self, inherit_axes: bool): + self.inherit_axes = inherit_axes + + @memoize_method + def __call__(self, arrays, axis): + if self.inherit_axes: + concat_axis_tag = UseInputAxis(0, axis) + else: + concat_axis_tag = ConcatenatedCallInputConcatAxisTag() + return concatenate( + arrays, + axis + ).with_tagged_axis(axis, frozenset({concat_axis_tag})).tagged( + ImplStored()) + + +# Memoize the creation of sliced output arrays to avoid copies +class _OutputSlicer: + def __init__(self, inherit_axes: bool): + self.inherit_axes = inherit_axes + + @memoize_method + def _get_slice( + self, + ary: Array, + axis: int, + start_idx: ShapeComponent, + end_idx: ShapeComponent): + indices = [slice(None) for i in range(ary.ndim)] + indices[axis] = slice(start_idx, end_idx) + if self.inherit_axes: + slice_axis_tag = UseInputAxis(None, axis) + else: + slice_axis_tag = ConcatenatedCallOutputSliceAxisTag() + sliced_ary = ary[tuple(indices)].with_tagged_axis( + axis, frozenset({slice_axis_tag})).tagged(ImplStored()) + assert isinstance(sliced_ary, BasicIndex) + return sliced_ary + + def __call__(self, ary, axis, slice_sizes): + start_indices: List[ShapeComponent] = [] + end_indices: List[ShapeComponent] = [] + if len(slice_sizes) > 0: + start_indices.append(0) + end_indices.append(slice_sizes[0]) + for islice in range(1, len(slice_sizes)): + start_indices.append(end_indices[-1]) + end_indices.append(end_indices[-1] + slice_sizes[islice]) + return [ + self._get_slice(ary, axis, start_idx, end_idx) + for start_idx, end_idx in zip(start_indices, end_indices)] + + +class _FunctionConcatenator(TransformMapperWithExtraArgs[Tuple[Array, ...]]): + def __init__(self, + current_stack: Tuple[Call, ...], + input_concatenator: _InputConcatenator, + ary_to_concatenatability: Mapping[ArrayOnStackT, Concatenatability], + _cache: TransformMapperCache[ + ArrayOrNames, [Tuple[Array, ...]]] | None = None, + _function_cache: TransformMapperCache[ + FunctionDefinition, [Tuple[Array, ...]]] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.current_stack = current_stack + self.input_concatenator = input_concatenator + self.ary_to_concatenatability = ary_to_concatenatability + + def get_cache_key( + self, expr: ArrayOrNames, exprs_from_other_calls: tuple[Array, ...] + ) -> tuple[ArrayOrNames, ...]: + return (expr, *exprs_from_other_calls) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control should not reach here." + " Call clone_with_new_call_on_stack instead.") + + def clone_with_new_call_on_stack(self, expr: Call) -> Self: + return type(self)( + self.current_stack + (expr,), + self.input_concatenator, + self.ary_to_concatenatability, + _function_cache=self._function_cache + ) + + def _get_concatenatability(self, expr: Array) -> Concatenatability: + return self.ary_to_concatenatability[(self.current_stack, expr)] + + def map_placeholder(self, + expr: Placeholder, + exprs_from_other_calls: Tuple[Array, ...] 
+ ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + new_shape = _get_concatenated_shape( + (expr,) + exprs_from_other_calls, concat.axis) + return Placeholder(name=expr.name, + dtype=expr.dtype, + shape=new_shape, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_data_wrapper(self, + expr: DataWrapper, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + return self.input_concatenator( + (expr,) + exprs_from_other_calls, concat.axis) + else: + raise NotImplementedError(type(concat)) + + def map_index_lambda(self, + expr: IndexLambda, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, IndexLambda) + for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are IndexLambda. + new_bindings = { + bnd_name: self.rec( + subexpr, + tuple(ary.bindings[bnd_name] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + for bnd_name, subexpr in expr.bindings.items() + } + new_shape = _get_concatenated_shape((expr,) + exprs_from_other_calls, + concat.axis) + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=immutabledict(new_bindings), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_einsum(self, expr: Einsum, + exprs_from_other_calls: Tuple[Array, ...]) -> Array: + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Einsum) for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Einsum. + new_args = [self.rec(arg, + tuple(ary.args[iarg] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + for iarg, arg in enumerate(expr.args)] + + return Einsum(expr.access_descriptors, + tuple(new_args), + expr.redn_axis_to_redn_descr, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def _map_index_base(self, expr: IndexBase, + exprs_from_other_calls: Tuple[Array, ...]) -> Array: + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, IndexBase) for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are IndexBase. 
+ new_indices = [ + self.rec(idx, + tuple(ary.indices[i_idx] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + if isinstance(idx, Array) + else idx + for i_idx, idx in enumerate(expr.indices) + ] + new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + + return type(expr)(array=new_array, + indices=tuple(new_indices), + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + map_contiguous_advanced_index = _map_index_base + map_non_contiguous_advanced_index = _map_index_base + map_basic_index = _map_index_base + + def map_roll(self, + expr: Roll, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert concat.axis != expr.axis + assert all(isinstance(ary, Roll) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Roll. + return Roll(self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)), + shift=expr.shift, + axis=expr.axis, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_stack(self, expr: Stack, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Stack) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Stack. + if any(len(ary.arrays) != len(expr.arrays) # type: ignore[attr-defined] + for ary in exprs_from_other_calls): + raise ValueError("Cannot concatenate stack expressions" + " with different number of arrays.") + + new_arrays = tuple( + self.rec(array, + tuple(subexpr.arrays[iarray] # type: ignore[attr-defined] + for subexpr in exprs_from_other_calls) + ) + for iarray, array in enumerate(expr.arrays)) + + return Stack(new_arrays, + expr.axis, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_concatenate(self, expr: Concatenate, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Concatenate) + for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Concatenate. + if any(len(ary.arrays) != len(expr.arrays) # type: ignore[attr-defined] + for ary in exprs_from_other_calls): + raise ValueError("Cannot concatenate concatenate-expressions" + " with different number of arrays.") + + new_arrays = tuple( + self.rec(array, + tuple(subexpr.arrays[iarray] # type: ignore[attr-defined] + for subexpr in exprs_from_other_calls) + ) + for iarray, array in enumerate(expr.arrays) + ) + + return Concatenate(new_arrays, + expr.axis, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_axis_permutation(self, expr: AxisPermutation, + exprs_from_other_calls: Tuple[Array, ...] 
+ ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, AxisPermutation) + for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are AxisPermutation. + new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + return AxisPermutation(new_array, + expr.axis_permutation, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_reshape(self, expr: Reshape, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + new_newshape = _get_concatenated_shape( + (expr,) + exprs_from_other_calls, concat.axis) + + assert all(isinstance(ary, Reshape) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Reshape. + new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + return Reshape(new_array, + new_newshape, + expr.order, + tags=expr.tags, + axes=expr.axes, + non_equality_tags=expr.non_equality_tags) + else: + raise NotImplementedError(type(concat)) + + def map_function_definition( + self, + expr: FunctionDefinition, + exprs_from_other_calls: Tuple[FunctionDefinition, ...] + ) -> FunctionDefinition: + # No clone here because we're cloning in map_call instead + new_returns = { + name: self.rec( + ret, + tuple( + other_expr.returns[name] + for other_expr in exprs_from_other_calls)) + for name, ret in expr.returns.items()} + if all( + new_ret is ret + for ret, new_ret in zip( + expr.returns.values(), + new_returns.values())): + return expr + else: + return attrs.evolve(expr, returns=immutabledict(new_returns)) + + def map_call(self, expr: Call, other_callsites: Tuple[Call, ...]) -> Call: + new_mapper = self.clone_with_new_call_on_stack(expr) + new_function = new_mapper.rec_function_definition( + expr.function, + tuple(other_call.function for other_call in other_callsites)) + new_bindings = {name: ( + self.rec( + bnd, tuple( + callsite.bindings[name] + for callsite in other_callsites)) + if isinstance(bnd, Array) + else bnd) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values()))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) + + def map_named_call_result(self, + expr: NamedCallResult, + exprs_from_other_calls: Tuple[Array, ...] 
+ ) -> Array:
+
+ concat = self._get_concatenatability(expr)
+
+ if isinstance(concat, ConcatableIfConstant):
+ return expr
+ elif isinstance(concat, ConcatableAlongAxis):
+ assert all(isinstance(ary, NamedCallResult)
+ for ary in exprs_from_other_calls)
+ assert isinstance(expr._container, Call)
+ new_call = self.rec(
+ expr._container,
+ tuple(ary._container # type: ignore[attr-defined]
+ for ary in exprs_from_other_calls))
+ return new_call[expr.name]
+ else:
+ raise NotImplementedError(type(concat))
+
+ def map_loopy_call_result(self, expr: "LoopyCallResult",
+ exprs_from_other_calls: Tuple[Array, ...],
+ ) -> Array:
+ raise ValueError("Loopy Calls are illegal to concatenate. Maybe"
+ " rewrite the operation as array operations?")
+
+
+ @memoize_on_first_arg
+ def _get_valid_concatenatability_constraints_simple(
+ template_call: Call, *other_calls: Call
+ ) -> Tuple[FunctionConcatenability, ...]:
+ template_fn = template_call.function
+ mapper = _InputConcatabilityGetter()
+ output_accs = {
+ name: mapper(
+ *[cs.function.returns[name] for cs in (template_call,) + other_calls])
+ for name in template_fn.returns}
+
+ return _combine_named_result_accs_simple(output_accs)
+
+
+ @memoize_on_first_arg
+ def _get_valid_concatenatability_constraints_exhaustive(
+ fn: FunctionDefinition) -> Generator[
+ FunctionConcatenability,
+ None,
+ None]:
+ mapper = _InputConcatabilityGetter()
+ output_accs = {name: mapper(output)
+ for name, output in fn.returns.items()}
+
+ yield from _combine_named_result_accs_exhaustive(output_accs)
+
+
+ def _get_ary_to_concatenatabilities(call_sites: Sequence[Call],
+ ) -> Generator[Mapping[ArrayOnStackT,
+ Concatenatability],
+ None,
+ None]:
+ """
+ For each valid concatenatability constraint of the function shared by
+ *call_sites*, generates a mapping from each array in the function body's
+ expression graph to its :class:`Concatenatability` criterion, provided
+ the call sites traverse identical function bodies.
+ """
+ fn_concatenatabilities = \
+ _get_valid_concatenatability_constraints_simple(*call_sites)
+
+ # select a template call site to start the traversal.
+ template_call, *other_calls = call_sites
+ template_fn = template_call.function
+ fid = next(iter(template_fn.tags_of_type(FunctionIdentifier)))
+
+ concat_idx_to_err_msg = {}
+
+ for iconcat, fn_concatenatability in enumerate(fn_concatenatabilities):
+ collector = _ConcatabilityCollector(current_stack=())
+
+ try:
+ # verify the constraints on parameters are satisfied
+ for name, input_concat in (fn_concatenatability
+ .input_to_concatenatability
+ .items()):
+ try:
+ if isinstance(input_concat, ConcatableIfConstant):
+ _verify_arrays_same([cs.bindings[name] for cs in call_sites])
+ elif isinstance(input_concat, ConcatableAlongAxis):
+ _verify_arrays_can_be_concated_along_axis(
+ [cs.bindings[name] for cs in call_sites],
+ [],
+ input_concat.axis)
+ else:
+ raise NotImplementedError(type(input_concat))
+ except _InvalidConcatenatability as e:
+ raise _InvalidConcatenatability(
+ f"Binding for input {name} is not concatenatable. {str(e)}")
+
+ # verify the constraints on function bodies are satisfied
+ for name, output_concat in (fn_concatenatability
+ .output_to_concatenatability
+ .items()):
+ try:
+ collector(template_call.function.returns[name],
+ output_concat,
+ tuple(other_call.function.returns[name]
+ for other_call in other_calls))
+ except _InvalidConcatenatability as e:
+ raise _InvalidConcatenatability(
+ f"Function output {name} is not concatenatable. {str(e)}")
{str(e)}") + except _InvalidConcatenatability as e: + concat_idx_to_err_msg[iconcat] = str(e) + else: + if collector.call_sites_on_hold: + raise NotImplementedError("Expressions that use part of" + " function's returned values are not" + " yet supported.") + + logger.info( + f"Found a valid concatenatability for function with ID '{fid}' --\n" + f"{fn_concatenatability}") + + yield immutabledict(collector.ary_to_concatenatability) + + log_str = ( + f"No more valid concatenatabilities for function with ID '{fid}'. " + "Unsuitable candidates:\n") + for iconcat, fn_concatenatability in enumerate(fn_concatenatabilities): + try: + err_msg = concat_idx_to_err_msg[iconcat] + except KeyError: + continue + log_str += f"Candidate:\n{fn_concatenatability}\n" + log_str += f"Error: {concat_idx_to_err_msg[iconcat]}\n\n" + logger.info(log_str) + + +def _get_replacement_map_post_concatenating( + call_sites: Sequence[Call], + used_call_results: frozenset(NamedCallResult), + input_concatenator: _InputConcatenator, + output_slicer: _OutputSlicer) -> Mapping[NamedCallResult, Array]: + """ + .. note:: + + We require *call_sites* to be ordered to determine the concatenation + order. + """ + assert call_sites, "Empty `call_sites`." + + ary_to_concatenatabilities = _get_ary_to_concatenatabilities(call_sites) + + template_call_site, *other_call_sites = call_sites + template_function = template_call_site.function + fid = next(iter(template_function.tags_of_type(FunctionIdentifier))) + + try: + ary_to_concatenatability = next(ary_to_concatenatabilities) + except StopIteration: + raise ValueError( + f"No valid concatenatibilities found for function with ID '{fid}'.") + else: + if __debug__: + try: + next(ary_to_concatenatabilities) + except StopIteration: + # unique concatenatibility + pass + else: + from warnings import warn + # TODO: Take some input from the user to resolve this ambiguity. + warn( + "Multiple concatenation possibilities found for function with " + f"ID '{fid}'. This may lead to non-deterministic transformed " + "expression graphs.") + + # {{{ actually perform the concatenation + + template_returns = template_function.returns + template_bindings = template_call_site.bindings + + function_concatenator = _FunctionConcatenator( + current_stack=(), input_concatenator=input_concatenator, + ary_to_concatenatability=ary_to_concatenatability) + + if __debug__: + # FIXME: We may be able to handle this without burdening the user + # See https://github.com/inducer/pytato/issues/559 + from collections import defaultdict + param_to_used_calls = defaultdict(set) + for output_name in template_call_site.keys(): + for csite in call_sites: + call_result = csite[output_name] + if call_result in used_call_results: + ret = csite.function.returns[output_name] + used_params = ( + { + expr.name + for expr in InputGatherer()(ret)} + & csite.function.parameters) + for name in used_params: + param_to_used_calls[name] |= {csite} + for name, used_calls in param_to_used_calls.items(): + if used_calls != set(call_sites): + from warnings import warn + warn( + f"DAG output does not depend on parameter '{name}' for some " + f"calls to function with ID '{fid}'. Concatenation will prevent " + "these unused inputs from being removed from the DAG when the " + "function is inlined. 
+
+ # new_returns: concatenated function body
+ new_returns: Dict[str, Array] = {}
+ for output_name in template_call_site.keys():
+ new_returns[output_name] = function_concatenator(
+ template_returns[output_name],
+ tuple(csite.function.returns[output_name]
+ for csite in other_call_sites))
+
+ # }}}
+
+ # construct new function body
+ if any(
+ new_returns[output_name] is not template_returns[output_name]
+ for output_name in template_returns):
+ new_function = FunctionDefinition(
+ template_call_site.function.parameters,
+ template_call_site.function.return_type,
+ immutabledict(new_returns),
+ tags=template_call_site.function.tags)
+ else:
+ new_function = template_call_site.function
+
+ result: Dict[NamedCallResult, Array] = {}
+
+ new_call_bindings: Dict[str, Array] = {}
+
+ # construct new bindings
+ for param_name in template_bindings:
+ param_placeholder = template_call_site.function.get_placeholder(param_name)
+ param_concat = ary_to_concatenatability[((), param_placeholder)]
+ if isinstance(param_concat, ConcatableAlongAxis):
+ param_bindings = tuple([
+ csite.bindings[param_name]
+ for csite in call_sites])
+ new_binding = input_concatenator(
+ param_bindings,
+ param_concat.axis)
+ elif isinstance(param_concat, ConcatableIfConstant):
+ new_binding = template_bindings[param_name]
+ else:
+ raise NotImplementedError(type(param_concat))
+ new_call_bindings[param_name] = new_binding
+
+ # construct new call
+ if (
+ new_function is not template_call_site.function
+ or any(
+ new_call_bindings[param_name] is not template_bindings[param_name]
+ for param_name in template_bindings)):
+ new_call = Call(
+ function=new_function,
+ bindings=immutabledict(new_call_bindings),
+ tags=template_call_site.tags)
+ else:
+ new_call = template_call_site
+
+ # slice into new_call's outputs to replace the old expressions.
+ for output_name, output_ary in (template_call_site
+ .function
+ .returns
+ .items()):
+ concat = ary_to_concatenatability[((), output_ary)]
+ new_return = new_call[output_name]
+ if isinstance(concat, ConcatableIfConstant):
+ # FIXME: Does it make sense to not concatenate if some arguments are
+ # ConcatableIfConstant and some are ConcatableAlongAxis? Seems like that
+ # would cause problems...
+ for cs in call_sites:
+ result[cs[output_name]] = new_return
+ elif isinstance(concat, ConcatableAlongAxis):
+ slice_sizes = [
+ cs[output_name].shape[concat.axis]
+ for cs in call_sites]
+ output_slices = output_slicer(new_return, concat.axis, slice_sizes)
+ for cs, output_slice in zip(call_sites, output_slices):
+ result[cs[output_name]] = output_slice
+ else:
+ raise NotImplementedError(type(concat))
+
+ return immutabledict(result)
+
+
+ def concatenate_calls(expr: ArrayOrNames,
+ call_site_filter: Callable[[CallSiteLocation], bool],
+ *,
+ inherit_axes: bool = False,
+ warn_if_no_calls: bool = True,
+ err_if_no_calls: bool = False,
+ ignore_tag_types: frozenset[type] | None = None,
+ ) -> ArrayOrNames:
+ r"""
+ Returns a copy of *expr* after concatenating all call-sites ``C`` such that
+ ``call_site_filter(C) is True``.
+
+ :arg call_site_filter: A callable to select which instances of
+ :class:`~pytato.function.Call`\ s must be concatenated.
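+
+ For example (a sketch, assuming *expr* contains calls to a function
+ traced under the hypothetical identifier ``"foo"``)::
+
+ concatenate_calls(
+ expr,
+ lambda cs: FunctionIdentifier("foo") in cs.call.function.tags)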
+ """ + if ignore_tag_types is None: + ignore_tag_types: frozenset(type) = frozenset() + + call_site_collector = CallSiteDependencyCollector(stack=()) + + all_call_sites = call_site_collector(expr) + filtered_call_sites = {cs + for cs in all_call_sites + if call_site_filter(cs)} + + function_ids = { + next(iter(cs.call.function.tags_of_type(FunctionIdentifier))) + for cs in filtered_call_sites} + + # Input concatenator needs to be set up outside of the loop in order to prevent + # creating duplicates; probably not strictly necessary for output slicer + input_concatenator = _InputConcatenator(inherit_axes=inherit_axes) + output_slicer = _OutputSlicer(inherit_axes=inherit_axes) + + result = expr + + for fid in function_ids: + call_site_dep_collector = CallSiteDependencyCollector(stack=()) + call_site_dep_collector(result) + + call_site_to_dep_call_sites = \ + call_site_dep_collector.call_site_to_dep_call_sites + + unbatched_call_sites: Set[CallSiteLocation] = { + cs for cs in call_site_to_dep_call_sites.keys() + if call_site_filter(cs) and fid in cs.call.function.tags} + + for cs in unbatched_call_sites: + for ret in cs.call.function.returns.values(): + nested_calls = collect_nodes_of_type(ret, Call) + if nested_calls: + raise NotImplementedError( + "Concatenation of nested calls is not yet supported.") + + call_site_batches: List[FrozenSet[CallSiteLocation]] = [] + + replacement_map: Dict[ + Tuple[NamedCallResult, Tuple[Call, ...]], + Array] = {} + + used_call_results = collect_nodes_of_type(result, NamedCallResult) + + while unbatched_call_sites: + ready_call_sites = frozenset({ + cs for cs in unbatched_call_sites + if not call_site_to_dep_call_sites[cs] & unbatched_call_sites}) + + from mpi4py import MPI + rank = MPI.COMM_WORLD.rank + + # if fid.identifier == "_make_fluid_state": + # print(f"{rank}: {len(ready_call_sites)=}") + + if not ready_call_sites: + raise ValueError("Found cycle in call site dependency graph.") + + template_call_site = next(iter(ready_call_sites)) + template_fn = template_call_site.call.function + + from pytato.equality import SimilarityComparer + similarity_comparer = SimilarityComparer( + ignore_tag_types=ignore_tag_types) + # err_on_not_similar=(fid.identifier == "_make_fluid_state")) + + # if fid.identifier == "_make_fluid_state": + # for cs in ready_call_sites: + # same_outputs = ( + # frozenset(cs.call.function.returns.keys()) + # == frozenset(template_fn.returns.keys())) + # similar = all( + # similarity_comparer( + # cs.call.function.returns[name], + # template_fn.returns[name]) + # for name in template_fn.returns) + # same_stack = (cs.stack == template_call_site.stack) + # print(f"{rank}: {same_outputs=}, {similar=}, {same_stack=}") + # # if not similar: + # # for name in template_fn.returns: + # # from pytato.analysis import get_num_nodes + # # nnodes_template = get_num_nodes(template_fn.returns[name]) + # # nnodes_other = get_num_nodes(cs.call.function.returns[name]) + # # print(f"{rank}: {name=}, {nnodes_template=}, {nnodes_other=}") + + similar_call_sites = frozenset({ + cs for cs in ready_call_sites + if ( + ( + frozenset(cs.call.function.returns.keys()) + == frozenset(template_fn.returns.keys())) + and all( + similarity_comparer( + cs.call.function.returns[name], + template_fn.returns[name]) + for name in template_fn.returns) + and cs.stack == template_call_site.stack)}) + + # if fid.identifier == "_make_fluid_state": + # print(f"{rank}: {len(similar_call_sites)=}") + + if not similar_call_sites: + raise ValueError("Failed to find similar call sites 
+
+ def get_axis0_len(cs):
+ first_out_name = next(iter(cs.call.function.returns.keys()))
+ axis0_len = cs.call[first_out_name].shape[0]
+ assert all(
+ cs.call[name].shape[0] == axis0_len
+ for name in cs.call.function.returns)
+ return axis0_len
+
+ batch_call_sites = sorted(similar_call_sites, key=get_axis0_len)
+
+ call_site_batches.append(batch_call_sites)
+ unbatched_call_sites -= frozenset(batch_call_sites)
+
+ # FIXME: this doesn't work; need to create/execute batches one at a time,
+ # then repeat the steps above to collect the updated call sites after
+ # concatenating the previous batch
+ for ibatch, call_sites in enumerate(call_site_batches):
+ template_fn = next(iter(call_sites)).call.function
+
+ # FIXME: Can't currently call get_num_nodes on a function definition
+ from pytato.array import make_dict_of_named_arrays
+ from pytato.analysis import get_num_nodes
+ fn_body = make_dict_of_named_arrays(template_fn.returns)
+ nnodes = get_num_nodes(fn_body)
+
+ logger.info(
+ f"Concatenating function '{fid}' (batch {ibatch+1} of "
+ f"{len(call_site_batches)}: {nnodes} nodes, {len(call_sites)} "
+ "call sites).")
+
+ if len(call_sites) <= 1:
+ if err_if_no_calls:
+ raise ValueError(
+ f"Not enough calls to concatenate function with ID '{fid}'.")
+ elif warn_if_no_calls:
+ from warnings import warn
+ warn(
+ f"Not enough calls to concatenate function with ID '{fid}'.",
+ stacklevel=2)
+ continue
+
+ old_expr_to_new_expr_map = _get_replacement_map_post_concatenating(
+ [cs.call for cs in call_sites],
+ used_call_results,
+ input_concatenator=input_concatenator,
+ output_slicer=output_slicer)
+
+ stack, = {cs.stack for cs in call_sites}
+
+ replacement_map.update({
+ (old_expr, stack): new_expr
+ for old_expr, new_expr in old_expr_to_new_expr_map.items()})
+
+ # FIXME: Still getting some duplicated `Concatenate`s, not sure why
+ dedup = Deduplicator()
+ result = dedup(result)
+ replacement_map = {
+ old_expr_and_stack: dedup(new_expr)
+ for old_expr_and_stack, new_expr in replacement_map.items()}
+
+ result = _NamedCallResultReplacerPostConcatenate(
+ replacement_map=replacement_map,
+ current_stack=())(result)
+
+ assert isinstance(result, (Array, AbstractResultWithNamedArrays))
+ return result
+
+# }}}
+
 # vim:foldmethod=marker
diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py
index 394f0904b..2898b0cef 100644
--- a/pytato/transform/lower_to_index_lambda.py
+++ b/pytato/transform/lower_to_index_lambda.py
@@ -148,13 +148,44 @@ def _get_reshaped_indices(
 assert product(new_shape) == 1
 return ()
- if expr.order not in ["C", "F"]:
+ if order not in ["C", "F"]:
 raise NotImplementedError("Order expected to be 'C' or 'F'",
- f" found {expr.order}")
+ f" found {order}")
+
+ non1_shape = []
+ for axis_len in new_shape:
+ assert isinstance(axis_len, INT_CLASSES)
+ if axis_len > 1:
+ non1_shape.append(axis_len)
+ non1_shape = tuple(non1_shape)
+
+ old_non1_shape = []
+ for axis_len in old_shape:
+ assert isinstance(axis_len, INT_CLASSES)
+ if axis_len > 1:
+ old_non1_shape.append(axis_len)
+ old_non1_shape = tuple(old_non1_shape)
+
+ if non1_shape == old_non1_shape:
+ non1_axes = tuple(
+ iaxis for iaxis in range(len(new_shape))
+ if new_shape[iaxis] > 1)
+ old_non1_axes = tuple(
+ iaxis for iaxis in range(len(old_shape))
+ if old_shape[iaxis] > 1)
+ old_iaxis_to_iaxis = {
+ old_iaxis: iaxis
+ for old_iaxis, iaxis in zip(
+ old_non1_axes, non1_axes)}
+ 
return tuple( + prim.Variable(f"_{old_iaxis_to_iaxis[old_iaxis]}") + if old_iaxis in old_iaxis_to_iaxis + else 0 + for old_iaxis in range(len(old_shape))) - if expr.order == "C": + if order == "C": newstrides: list[IntegralT] = [1] # reshaped array strides - for new_axis_len in reversed(expr.shape[1:]): + for new_axis_len in reversed(new_shape[1:]): assert isinstance(new_axis_len, INT_CLASSES) newstrides.insert(0, newstrides[0]*new_axis_len) @@ -162,20 +193,20 @@ def _get_reshaped_indices( for i, stride in enumerate(newstrides)) oldstrides: list[IntegralT] = [1] # input array strides - for axis_len in reversed(expr.array.shape[1:]): + for axis_len in reversed(old_shape[1:]): assert isinstance(axis_len, INT_CLASSES) oldstrides.insert(0, oldstrides[0]*axis_len) - assert isinstance(expr.array.shape[-1], INT_CLASSES) - oldsizetills = [expr.array.shape[-1]] # input array size + assert isinstance(old_shape[-1], INT_CLASSES) + oldsizetills = [old_shape[-1]] # input array size # till for axes idx - for old_axis_len in reversed(expr.array.shape[:-1]): + for old_axis_len in reversed(old_shape[:-1]): assert isinstance(old_axis_len, INT_CLASSES) oldsizetills.insert(0, oldsizetills[0]*old_axis_len) else: newstrides: list[IntegralT] = [1] # reshaped array strides - for new_axis_len in expr.shape[:-1]: + for new_axis_len in new_shape[:-1]: assert isinstance(new_axis_len, INT_CLASSES) newstrides.append(newstrides[-1]*new_axis_len) @@ -183,13 +214,13 @@ def _get_reshaped_indices( for i, stride in enumerate(newstrides)) oldstrides: list[IntegralT] = [1] # input array strides - for axis_len in expr.array.shape[:-1]: + for axis_len in old_shape[:-1]: assert isinstance(axis_len, INT_CLASSES) oldstrides.append(oldstrides[-1]*axis_len) - assert isinstance(expr.array.shape[0], INT_CLASSES) - oldsizetills = [expr.array.shape[0]] # input array size till for axes idx - for old_axis_len in expr.array.shape[1:]: + assert isinstance(old_shape[0], INT_CLASSES) + oldsizetills = [old_shape[0]] # input array size till for axes idx + for old_axis_len in old_shape[1:]: assert isinstance(old_axis_len, INT_CLASSES) oldsizetills.append(oldsizetills[-1]*old_axis_len) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index de3625978..e3e880cf7 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -70,6 +70,7 @@ IndexLambda, InputArgumentBase, NamedArray, + NormalizedSlice, Reshape, Stack, ) @@ -79,8 +80,10 @@ IDX_LAMBDA_AXIS_INDEX, CombineMapper, ) +from pytato.tags import UseInputAxis from pytato.transform import ArrayOrNames, CopyMapper, Mapper, TransformMapperCache from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.utils import are_shape_components_equal logger = logging.getLogger(__name__) @@ -326,7 +329,26 @@ def map_stack(self, expr: Stack) -> None: def map_concatenate(self, expr: Concatenate) -> None: for ary in expr.arrays: self.rec(ary) - self.add_equations_using_index_lambda_version_of_expr(expr) + # FIXME: Figure out how to integrate the UseInputAxis stuff into + # add_equations_using_index_lambda_version_of_expr + # self.add_equations_using_index_lambda_version_of_expr(expr) + for ary in expr.arrays: + assert ary.ndim == expr.ndim + for iaxis in range(expr.ndim): + if iaxis == expr.axis: + use_input_axis_tags = expr.axes[iaxis].tags_of_type( + UseInputAxis) + if use_input_axis_tags: + tag, = use_input_axis_tags + self.record_equation( + self.get_var_for_axis(expr.arrays[tag.key], tag.axis), + self.get_var_for_axis(expr, iaxis)) + else: + # 
non-concatenated axes share the dimensions. + self.record_equation( + self.get_var_for_axis(ary, iaxis), + self.get_var_for_axis(expr, iaxis) + ) def map_axis_permutation(self, expr: AxisPermutation ) -> None: @@ -335,7 +357,36 @@ def map_axis_permutation(self, expr: AxisPermutation def map_basic_index(self, expr: BasicIndex) -> None: self.rec(expr.array) - self.add_equations_using_index_lambda_version_of_expr(expr) + # FIXME: Figure out how to integrate the UseInputAxis stuff into + # add_equations_using_index_lambda_version_of_expr + # self.add_equations_using_index_lambda_version_of_expr(expr) + i_out_axis = 0 + + assert len(expr.indices) == expr.array.ndim + + for i_in_axis, idx in enumerate(expr.indices): + if isinstance(idx, int): + pass + else: + assert isinstance(idx, NormalizedSlice) + use_input_axis_tags = expr.axes[i_out_axis].tags_of_type( + UseInputAxis) + if use_input_axis_tags: + tag, = use_input_axis_tags + self.record_equation( + self.get_var_for_axis(expr.array, tag.axis), + self.get_var_for_axis(expr, i_out_axis)) + elif (idx.step == 1 + and are_shape_components_equal(idx.start, 0) + and are_shape_components_equal(idx.stop, + expr.array.shape[i_in_axis])): + + self.record_equation( + self.get_var_for_axis(expr.array, i_in_axis), + self.get_var_for_axis(expr, i_out_axis) + ) + + i_out_axis += 1 def map_contiguous_advanced_index(self, expr: AdvancedIndexInContiguousAxes diff --git a/pytato/utils.py b/pytato/utils.py index 0039a8ff4..0af2b79da 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -30,6 +30,7 @@ cast, ) +import islpy as isl from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, Optional, Iterable, TypeVar, FrozenSet) from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, diff --git a/test/test_codegen.py b/test/test_codegen.py index 7f6cedb12..83be983d8 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1612,7 +1612,7 @@ def test_zero_size_cl_array_dedup(ctx_factory): x4 = pt.make_data_wrapper(x_cl2) out = pt.make_dict_of_named_arrays({"out1": 2*x1, - "out2": 3*x2, + "out2": 2*x2, "out3": x3 + x4 }) @@ -1963,7 +1963,8 @@ def build_expression(tracer): np.testing.assert_allclose(outputs[key], expected[key]) -def test_nested_function_calls(ctx_factory): +@pytest.mark.parametrize("should_concatenate_bar", (False, True)) +def test_nested_function_calls(ctx_factory, should_concatenate_bar): from functools import partial ctx = ctx_factory() @@ -1996,7 +1997,16 @@ def call_bar(tracer, x, y): result = pt.make_dict_of_named_arrays({"out1": call_bar(pt.trace_call, x1, y1), "out2": call_bar(pt.trace_call, x2, y2)} ) + result = pt.transform.Deduplicator()(result) result = pt.tag_all_calls_to_be_inlined(result) + if should_concatenate_bar: + from pytato.transform.calls import CallSiteDependencyCollector + assert len(CallSiteDependencyCollector(())(result)) == 4 + result = pt.concatenate_calls( + result, + lambda x: pt.tags.FunctionIdentifier("bar") in x.call.function.tags) + assert len(CallSiteDependencyCollector(())(result)) == 2 + expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1), "out2": call_bar(ref_tracer, x2, y2)} ) @@ -2068,6 +2078,111 @@ def test_pow_arg_casting(ctx_factory): (float, np.float32, np.float64) +def test_concatenate_calls_no_nested(ctx_factory): + rng = np.random.default_rng(0) + + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + def foo(x, y): + return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y) + + x1 = pt.make_placeholder("x1", (10, 4), 
np.float64)
+ x2 = pt.make_placeholder("x2", (10, 4), np.float64)
+
+ y1 = pt.make_placeholder("y1", (10, 4), np.float64)
+ y2 = pt.make_placeholder("y2", (10, 4), np.float64)
+
+ z1 = pt.make_placeholder("z1", (10, 4), np.float64)
+ z2 = pt.make_placeholder("z2", (10, 4), np.float64)
+
+ result = pt.make_dict_of_named_arrays({"out1": 2*pt.trace_call(foo, 2*x1, 3*x2),
+ "out2": 4*pt.trace_call(foo, 4*y1, 9*y2),
+ "out3": 6*pt.trace_call(foo, 7*z1, 8*z2)
+ })
+ result = pt.transform.Deduplicator()(result)
+
+ concatenated_result = pt.concatenate_calls(
+ result, lambda x: pt.tags.FunctionIdentifier("foo") in x.call.function.tags)
+
+ result = pt.tag_all_calls_to_be_inlined(result)
+ concatenated_result = pt.tag_all_calls_to_be_inlined(concatenated_result)
+
+ assert (pt.analysis.get_num_nodes(pt.inline_calls(result))
+ > pt.analysis.get_num_nodes(pt.inline_calls(concatenated_result)))
+
+ x1_np, x2_np, y1_np, y2_np, z1_np, z2_np = rng.random((6, 10, 4))
+
+ _, out_dict1 = pt.generate_loopy(result)(cq,
+ x1=x1_np, x2=x2_np,
+ y1=y1_np, y2=y2_np,
+ z1=z1_np, z2=z2_np)
+
+ _, out_dict2 = pt.generate_loopy(concatenated_result)(cq,
+ x1=x1_np, x2=x2_np,
+ y1=y1_np, y2=y2_np,
+ z1=z1_np, z2=z2_np)
+ assert out_dict1.keys() == out_dict2.keys()
+
+ for key in out_dict1:
+ np.testing.assert_allclose(out_dict1[key], out_dict2[key])
+
+
+def test_concatenation_via_constant_expressions(ctx_factory):
+
+ from pytato.transform.calls import CallSiteDependencyCollector
+
+ rng = np.random.default_rng(0)
+
+ ctx = ctx_factory()
+ cq = cl.CommandQueue(ctx)
+
+ def resampling(coords, iels):
+ return coords[iels]
+
+ n_el = 1000
+ n_dof = 20
+ n_dim = 3
+
+ n_left_els = 17
+ n_right_els = 29
+
+ coords_dofs_np = rng.random((n_el, n_dim, n_dof), np.float64)
+ left_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_left_els)
+ right_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_right_els)
+
+ coords_dofs = pt.make_data_wrapper(coords_dofs_np)
+ left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np)
+ right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np)
+
+ lcoords = pt.trace_call(resampling, coords_dofs, left_bnd_iels)
+ rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels)
+
+ result = pt.make_dict_of_named_arrays({"lcoords": lcoords,
+ "rcoords": rcoords})
+ result = pt.transform.Deduplicator()(result)
+ result = pt.tag_all_calls_to_be_inlined(result)
+
+ assert len(CallSiteDependencyCollector(())(result)) == 2
+ concated_result = pt.concatenate_calls(
+ result,
+ lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags
+ )
+ assert len(CallSiteDependencyCollector(())(concated_result)) == 1
+
+ _, out_result = pt.generate_loopy(result)(cq)
+ np.testing.assert_allclose(out_result["lcoords"],
+ coords_dofs_np[left_bnd_iels_np])
+ np.testing.assert_allclose(out_result["rcoords"],
+ coords_dofs_np[right_bnd_iels_np])
+
+ _, out_concated_result = pt.generate_loopy(concated_result)(cq)
+ np.testing.assert_allclose(out_concated_result["lcoords"],
+ coords_dofs_np[left_bnd_iels_np])
+ np.testing.assert_allclose(out_concated_result["rcoords"],
+ coords_dofs_np[right_bnd_iels_np])
+
+
 if __name__ == "__main__":
 if len(sys.argv) > 1:
 exec(sys.argv[1])

From 3ce83ca90114f08e020d414e4786f2c90fbc920a Mon Sep 17 00:00:00 2001
From: Matthew Smith
Date: Sun, 23 Mar 2025 18:29:57 -0500
Subject: [PATCH 177/178] don't assume that all ranks communicate if
 communication is present on some ranks

---
 pytato/distributed/partition.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git 
a/pytato/distributed/partition.py b/pytato/distributed/partition.py index fcf8f899e..98f8a9027 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -808,6 +808,7 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids part_comm_ids: list[_PartCommIDs] = [] + if comm_batches: recv_ids: FrozenOrderedSet[CommunicationOpIdentifier] = FrozenOrderedSet() for batch in comm_batches: @@ -828,7 +829,8 @@ def find_distributed_partition( _PartCommIDs( recv_ids=recv_ids, send_ids=FrozenOrderedSet())) - else: + + if not part_comm_ids: part_comm_ids.append( _PartCommIDs( recv_ids=FrozenOrderedSet(), From b864835f9bd064c7677a6a9814a531fffa95daa9 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 1 Apr 2025 20:43:08 -0700 Subject: [PATCH 178/178] make concatenation deterministic --- pytato/analysis/__init__.py | 22 +++--- pytato/transform/__init__.py | 39 +++++++---- pytato/transform/calls.py | 128 +++++++++++++++++------------------ 3 files changed, 103 insertions(+), 86 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index dc505cee5..75eeb2723 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any, Never -from orderedsets import FrozenOrderedSet +from orderedsets import FrozenOrderedSet, OrderedSet from typing_extensions import Self from loopy.tools import LoopyKeyBuilder @@ -107,6 +107,7 @@ # {{{ NUserCollector +# FIXME: Use ordered sets class NUserCollector(Mapper[None, None, []]): """ A :class:`pytato.transform.CachedWalkMapper` that records the number of @@ -250,6 +251,7 @@ def get_nusers(outputs: Array | DictOfNamedArrays) -> Mapping[Array, int]: # {{{ is_einsum_similar_to_subscript +# FIXME: Use ordered sets def _get_indices_from_input_subscript(subscript: str, is_output: bool, ) -> tuple[str, ...]: @@ -475,7 +477,7 @@ def __init__( self, count_duplicates: bool = False, traverse_functions: bool = True, - _visited_functions: set[Any] | None = None, + _visited_functions: OrderedSet[Any] | None = None, ) -> None: super().__init__(_visited_functions=_visited_functions) @@ -612,7 +614,7 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]): def __init__( self, traverse_functions: bool = True, - _visited_functions: set[Any] | None = None) -> None: + _visited_functions: OrderedSet[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) self.traverse_functions = traverse_functions @@ -677,7 +679,7 @@ class CallSiteCountMapper(CachedWalkMapper[[]]): The number of nodes. 
""" - def __init__(self, _visited_functions: set[Any] | None = None) -> None: + def __init__(self, _visited_functions: OrderedSet[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) self.count = 0 @@ -719,6 +721,7 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int: # {{{ TagCountMapper +# FIXME: Use ordered sets class TagCountMapper(CombineMapper[int, Never]): """ Returns the number of nodes in a DAG that are tagged with all the tag types in @@ -819,7 +822,7 @@ def __init__( super().__init__() self.collect_func = collect_func self.traverse_functions = traverse_functions - self.nodes: set[NodeT] = set() + self.nodes: OrderedSet[NodeT] = OrderedSet() def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr @@ -854,7 +857,7 @@ def map_function_definition(self, expr: FunctionDefinition) -> None: def collect_nodes_of_type( outputs: Array | DictOfNamedArrays, - node_type: type[NodeT]) -> frozenset[NodeT]: + node_type: type[NodeT]) -> FrozenOrderedSet[NodeT]: """Returns the nodes that are instances of *node_type* in DAG *outputs*.""" from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) @@ -865,11 +868,11 @@ def collect_func(expr: NodeT) -> bool: nc = NodeCollector(collect_func) nc(outputs) - return frozenset(nc.nodes) + return FrozenOrderedSet(nc.nodes) def collect_materialized_nodes( - outputs: Array | DictOfNamedArrays) -> frozenset[NodeT]: + outputs: Array | DictOfNamedArrays) -> FrozenOrderedSet[NodeT]: from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) @@ -880,13 +883,14 @@ def collect_func(expr: NodeT) -> bool: nc = NodeCollector(collect_func) nc(outputs) - return frozenset(nc.nodes) + return FrozenOrderedSet(nc.nodes) # }}} # {{{ DependencyTracer +# FIXME: Use ordered sets class DependencyTracer(CombineMapper[frozenset[tuple[Array, ...]], Never]): """ Maps a DAG and a node to a :class:`frozenset` of `tuple`\\ s of diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 04e373f22..9a1e1437b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -43,6 +43,7 @@ import numpy as np from immutabledict import immutabledict +from orderedsets import FrozenOrderedSet, OrderedSet from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper @@ -1523,6 +1524,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> ResultT: # {{{ DependencyMapper +# FIXME: Change to ordered sets (including R) class DependencyMapper(CombineMapper[R, Never]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of @@ -1601,6 +1603,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: # {{{ SubsetDependencyMapper +# FIXME: Change to ordered sets class SubsetDependencyMapper(DependencyMapper): """ Mapper to combine the dependencies of an expression that are a subset of @@ -1621,6 +1624,7 @@ def combine(self, *args: frozenset[Array]) -> frozenset[Array]: # {{{ InputGatherer +# FIXME: Change to ordered sets class InputGatherer( CombineMapper[frozenset[InputArgumentBase], frozenset[InputArgumentBase]]): """ @@ -1672,6 +1676,7 @@ def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: # {{{ precompute_subexpressions +# FIXME: Change to ordered sets # FIXME: Think about what happens when subexpressions contain outlined functions class _PrecomputableSubexpressionGatherer( CombineMapper[frozenset[Array], frozenset[Array]]): @@ -1802,6 +1807,7 @@ def precompute_subexpressions( # {{{ 
SizeParamGatherer +# FIXME: Change to ordered sets class SizeParamGatherer( CombineMapper[frozenset[SizeParam], frozenset[SizeParam]]): """ @@ -2053,13 +2059,13 @@ class CachedWalkMapper(WalkMapper[P]): def __init__( self, - _visited_functions: set[VisitKeyT] | None = None + _visited_functions: OrderedSet[VisitKeyT] | None = None ) -> None: super().__init__() - self._visited_arrays_or_names: set[VisitKeyT] = set() + self._visited_arrays_or_names: OrderedSet[VisitKeyT] = OrderedSet() - self._visited_functions: set[VisitKeyT] = \ - _visited_functions if _visited_functions is not None else set() + self._visited_functions: OrderedSet[VisitKeyT] = \ + _visited_functions if _visited_functions is not None else OrderedSet() def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -2110,7 +2116,7 @@ class TopoSortMapper(CachedWalkMapper[[]]): def __init__( self, - _visited_functions: set[VisitKeyT] | None = None) -> None: + _visited_functions: OrderedSet[VisitKeyT] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) self.topological_order: list[Array] = [] @@ -2247,17 +2253,17 @@ def _materialize_if_mpms(expr: Array, """ from functools import reduce - materialized_predecessors: frozenset[Array] = reduce( - frozenset.union, + materialized_predecessors: FrozenOrderedSet[Array] = reduce( + FrozenOrderedSet.union, (pred.materialized_predecessors for pred in predecessors), - frozenset()) + FrozenOrderedSet()) if nsuccessors > 1 and len(materialized_predecessors) > 1: if not expr.tags_of_type(ImplStored): new_expr = expr.tagged(ImplStored()) else: new_expr = expr - return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr) + return MPMSMaterializerAccumulator(FrozenOrderedSet([new_expr]), new_expr) else: return MPMSMaterializerAccumulator(materialized_predecessors, expr) @@ -2309,7 +2315,7 @@ def clone_for_callee( def _map_input_base(self, expr: InputArgumentBase ) -> MPMSMaterializerAccumulator: - return MPMSMaterializerAccumulator(frozenset([expr]), expr) + return MPMSMaterializerAccumulator(FrozenOrderedSet([expr]), expr) map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -2327,7 +2333,9 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: for bnd_name, bnd in sorted(children_rec.items())}) if ( - frozenset(new_children.keys()) == frozenset(expr.bindings.keys()) + ( + FrozenOrderedSet(new_children.keys()) + == FrozenOrderedSet(expr.bindings.keys())) and all( new_children[name] is expr.bindings[name] for name in expr.bindings)): @@ -2475,7 +2483,7 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays def map_loopy_call_result(self, expr: NamedArray) -> MPMSMaterializerAccumulator: # loopy call result is always materialized - return MPMSMaterializerAccumulator(frozenset([expr]), expr) + return MPMSMaterializerAccumulator(FrozenOrderedSet([expr]), expr) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder @@ -2500,7 +2508,7 @@ def map_distributed_send_ref_holder(self, def map_distributed_recv(self, expr: DistributedRecv ) -> MPMSMaterializerAccumulator: - return MPMSMaterializerAccumulator(frozenset([expr]), expr) + return MPMSMaterializerAccumulator(FrozenOrderedSet([expr]), expr) def map_named_call_result(self, expr: NamedCallResult ) -> MPMSMaterializerAccumulator: @@ -2530,6 +2538,7 @@ def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, return DictOfNamedArrays(data, tags=source_dict.tags) +# FIXME: Use ordered sets def get_dependencies(expr: 
DictOfNamedArrays) -> dict[str, frozenset[Array]]: """Returns the dependencies of each named array in *expr*. """ @@ -2624,6 +2633,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: # {{{ UsersCollector +# FIXME: Use ordered sets class UsersCollector(CachedMapper[None, Never, []]): """ Maps a graph to a dictionary representation mapping a node to its users, @@ -2765,6 +2775,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> None: self.rec(expr._container) +# FIXME: Use ordered sets def get_users(expr: ArrayOrNames) -> dict[ArrayOrNames, set[ArrayOrNames]]: """ @@ -2779,6 +2790,7 @@ def get_users(expr: ArrayOrNames) -> dict[ArrayOrNames, # {{{ operations on graphs in dict form +# FIXME: Use ordered sets def _recursively_get_all_users( direct_users: Mapping[ArrayOrNames, set[ArrayOrNames]], node: ArrayOrNames) -> frozenset[ArrayOrNames]: @@ -2803,6 +2815,7 @@ def _recursively_get_all_users( return frozenset(result) +# FIXME: Use ordered sets def rec_get_user_nodes(expr: ArrayOrNames, node: ArrayOrNames, ) -> frozenset[ArrayOrNames]: diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index f3e68302b..93f71760f 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -44,12 +44,10 @@ Callable, Collection, Dict, - FrozenSet, Generator, List, Never, Sequence, - Set, Tuple, cast, ) @@ -57,6 +55,7 @@ import attrs from immutabledict import immutabledict +from orderedsets import FrozenOrderedSet, OrderedSet import pymbolic.primitives as prim from pytools import memoize_method, memoize_on_first_arg @@ -244,12 +243,12 @@ def __init__( self, _fn_input_gatherers: dict[FunctionDefinition, InputGatherer] | None = None, - _visited_functions: set[Any] | None = None + _visited_functions: OrderedSet[Any] | None = None ) -> None: if _fn_input_gatherers is None: _fn_input_gatherers = {} - self.call_to_used_inputs: dict[Call, set[Placeholder]] = {} + self.call_to_used_inputs: dict[Call, OrderedSet[Placeholder]] = {} self._fn_input_gatherers = _fn_input_gatherers super().__init__(_visited_functions=_visited_functions) @@ -277,14 +276,14 @@ def map_named_call_result( input_gatherer = InputGatherer() self._fn_input_gatherers[call.function] = input_gatherer - used_inputs = self.call_to_used_inputs.setdefault(call, set()) + used_inputs = self.call_to_used_inputs.setdefault(call, OrderedSet()) used_inputs |= input_gatherer(call.function.returns[expr.name]) super().map_named_call_result(expr) def _collect_used_call_inputs( - expr: ArrayOrNames) -> immutabledict[Call, frozenset[Placeholder]]: + expr: ArrayOrNames) -> immutabledict[Call, FrozenOrderedSet[Placeholder]]: """ Returns a mapping from :class:`~pytato.function.Call` to the set of input :class:`~pt.array.Placeholder`\ s belonging to its function definition that are @@ -296,7 +295,7 @@ def _collect_used_call_inputs( collector(expr) return immutabledict({ - call: frozenset(inputs) + call: FrozenOrderedSet(inputs) for call, inputs in collector.call_to_used_inputs.items()}) # }}} @@ -311,7 +310,7 @@ class _UnusedCallBindingZeroer(CopyMapper): """ def __init__( self, - call_to_used_inputs: Mapping[Call, frozenset[Placeholder]], + call_to_used_inputs: Mapping[Call, FrozenOrderedSet[Placeholder]], _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: @@ -413,7 +412,7 @@ class CallSiteLocation: class CallSiteDependencyCollector( - CombineMapper[FrozenSet[CallSiteLocation], Never]): + 
CombineMapper[FrozenOrderedSet[CallSiteLocation], Never]): r""" Collects all the call sites in a :mod:`pytato` expression along with their interdependencies. @@ -437,14 +436,14 @@ def __init__(self, stack: Tuple[Call, ...]) -> None: Dict[CallSiteLocation, CallSiteLocation] = {} super().__init__() - def combine(self, *args: FrozenSet[CallSiteLocation] - ) -> FrozenSet[CallSiteLocation]: - return reduce(lambda a, b: a | b, args, frozenset()) + def combine(self, *args: FrozenOrderedSet[CallSiteLocation] + ) -> FrozenOrderedSet[CallSiteLocation]: + return reduce(lambda a, b: a | b, args, FrozenOrderedSet()) - def map_size_param(self, expr: SizeParam) -> FrozenSet[CallSiteLocation]: - return frozenset() + def map_size_param(self, expr: SizeParam) -> FrozenOrderedSet[CallSiteLocation]: + return FrozenOrderedSet() - def map_call(self, expr: Call) -> FrozenSet[CallSiteLocation]: + def map_call(self, expr: Call) -> FrozenOrderedSet[CallSiteLocation]: cs = CallSiteLocation(expr, self.stack) new_mapper_for_fn = CallSiteDependencyCollector(stack=self.stack + (expr,)) @@ -459,7 +458,7 @@ def map_call(self, expr: Call) -> FrozenSet[CallSiteLocation]: self.call_site_to_dep_call_sites.update( new_mapper_for_fn.call_site_to_dep_call_sites) - return self.combine(frozenset([cs]), dependent_call_sites) + return self.combine(FrozenOrderedSet([cs]), dependent_call_sites) class _NamedCallResultReplacerPostConcatenate(CopyMapper): @@ -577,7 +576,7 @@ class _InputConcatabilityGetterAcc: .. attribute:: seen_inputs - A :class:`frozenset` of all :class:`pytato.InputArgumentBase` + A :class:`FrozenOrderedSet` of all :class:`pytato.InputArgumentBase` predecessors of a node. .. attribute:: input_concatability @@ -592,13 +591,13 @@ class _InputConcatabilityGetterAcc: concatenation cannot be performed along those axes for the mapped array. 
""" - seen_inputs: FrozenSet[InputArgumentBase] + seen_inputs: FrozenOrderedSet[InputArgumentBase] input_concatability: Mapping[Concatenatability, Mapping[InputArgumentBase, Concatenatability]] def __post_init__(self) -> None: assert all( - frozenset(input_concat.keys()) == self.seen_inputs + FrozenOrderedSet(input_concat.keys()) == self.seen_inputs for input_concat in self.input_concatability.values()) __attrs_post_init__ = __post_init__ @@ -732,10 +731,10 @@ def _combine_input_accs( input_concatabilities: Dict[Concatenatability, Mapping[InputArgumentBase, Concatenatability]] = {} - seen_inputs: FrozenSet[InputArgumentBase] = reduce( - frozenset.union, + seen_inputs: FrozenOrderedSet[InputArgumentBase] = reduce( + FrozenOrderedSet.union, (operand_acc.seen_inputs for operand_acc in operand_accs), - frozenset()) + FrozenOrderedSet()) # The core logic here is to filter the iaxis in out_axis_to_operand_axes # so that all the operands agree on how the input arguments must be @@ -831,18 +830,18 @@ def _combine_named_result_accs_simple( valid_concatenatabilities: List[FunctionConcatenability] = [] input_args = reduce( - frozenset.union, + FrozenOrderedSet.union, [ acc.seen_inputs for acc in named_result_accs.values()], - frozenset()) + FrozenOrderedSet()) candidate_concat_axes = reduce( - frozenset.union, + FrozenOrderedSet.union, [ - frozenset(acc.input_concatability.keys()) + FrozenOrderedSet(acc.input_concatability.keys()) for acc in named_result_accs.values()], - frozenset()) + FrozenOrderedSet()) # print(f"{candidate_concat_axes=}") @@ -946,7 +945,7 @@ def _map_input_arg_base( input_concatenatability[ConcatableIfConstant()] = immutabledict( {expr: ConcatableIfConstant()}) - return _InputConcatabilityGetterAcc(frozenset([expr]), + return _InputConcatabilityGetterAcc(FrozenOrderedSet([expr]), immutabledict(input_concatenatability)) map_placeholder = _map_input_arg_base @@ -1035,10 +1034,10 @@ def map_named_call_result( valid_concatenatabilities = _get_valid_concatenatability_constraints_simple( expr._container.function) - expr_concat_possibilities = { + expr_concat_possibilities = FrozenOrderedSet( valid_concatenability.output_to_concatenatability[expr.name] for valid_concatenability in valid_concatenatabilities - } + ) input_concatenatabilities: Dict[Concatenatability, Mapping[InputArgumentBase, @@ -1046,7 +1045,7 @@ def map_named_call_result( rec_bindings = {bnd_name: self.rec(binding) for bnd_name, binding in expr._container.bindings.items()} callee_acc = self.rec(expr._container.function.returns[expr.name]) - seen_inputs: Set[InputArgumentBase] = set() + seen_inputs: OrderedSet[InputArgumentBase] = OrderedSet() for seen_input in callee_acc.seen_inputs: if isinstance(seen_input, Placeholder): @@ -1099,7 +1098,7 @@ def map_named_call_result( input_concatenatabilities[concat_possibility] = immutabledict( caller_input_concatabilities) - return _InputConcatabilityGetterAcc(frozenset(seen_inputs), + return _InputConcatabilityGetterAcc(FrozenOrderedSet(seen_inputs), immutabledict(input_concatenatabilities)) def map_loopy_call_result( @@ -1154,11 +1153,11 @@ class _ConcatabilityCollector(CachedWalkMapper): def __init__( self, current_stack: Tuple[Call, ...], - _visited_functions: set[Any] | None = None + _visited_functions: OrderedSet[Any] | None = None ) -> None: self.ary_to_concatenatability: Dict[ArrayOnStackT, Concatenatability] = {} self.current_stack = current_stack - self.call_sites_on_hold: Set[Call] = set() + self.call_sites_on_hold: OrderedSet[Call] = OrderedSet() 
super().__init__(_visited_functions=_visited_functions)

 # type-ignore-reason: CachedWalkMaper takes variadic `*args, **kwargs`.
@@ -1293,7 +1292,7 @@ def map_call(self, # type: ignore[override]
 for named_result in expr.values()):
 self.call_sites_on_hold.add(expr)
 else:
- self.call_sites_on_hold -= {expr}
+ self.call_sites_on_hold.discard(expr)
 # FIXME The code below bypasses caching of function definitions
 new_mapper = self.clone_with_new_call_on_stack(expr)
 for name, val_in_callee in expr.function.returns.items():
@@ -1890,7 +1889,7 @@ def _get_ary_to_concatenatabilities(call_sites: Sequence[Call],

 def _get_replacement_map_post_concatenating(
 call_sites: Sequence[Call],
- used_call_results: frozenset(NamedCallResult),
+ used_call_results: FrozenOrderedSet[NamedCallResult],
 input_concatenator: _InputConcatenator,
 output_slicer: _OutputSlicer) -> Mapping[NamedCallResult, Array]:
 """
@@ -1940,21 +1939,21 @@ def _get_replacement_map_post_concatenating(
 # FIXME: We may be able to handle this without burdening the user
 # See https://github.com/inducer/pytato/issues/559
 from collections import defaultdict
- param_to_used_calls = defaultdict(set)
+ param_to_used_calls = defaultdict(OrderedSet)
 for output_name in template_call_site.keys():
 for csite in call_sites:
 call_result = csite[output_name]
 if call_result in used_call_results:
 ret = csite.function.returns[output_name]
 used_params = (
- {
+ OrderedSet(
 expr.name
- for expr in InputGatherer()(ret)}
+ for expr in InputGatherer()(ret))
 & csite.function.parameters)
 for name in used_params:
- param_to_used_calls[name] |= {csite}
+ param_to_used_calls[name].add(csite)
 for name, used_calls in param_to_used_calls.items():
- if used_calls != set(call_sites):
+ if used_calls != OrderedSet(call_sites):
 from warnings import warn
 warn(
 f"DAG output does not depend on parameter '{name}' for some "
@@ -2065,13 +2064,13 @@ def concatenate_calls(expr: ArrayOrNames,
 call_site_collector = CallSiteDependencyCollector(stack=())
 all_call_sites = call_site_collector(expr)
- filtered_call_sites = {cs
+ filtered_call_sites = FrozenOrderedSet(cs
 for cs in all_call_sites
- if call_site_filter(cs)}
+ if call_site_filter(cs))

- function_ids = {
+ function_ids = FrozenOrderedSet(
 next(iter(cs.call.function.tags_of_type(FunctionIdentifier)))
- for cs in filtered_call_sites}
+ for cs in filtered_call_sites)

 # Input concatenator needs to be set up outside of the loop in order to prevent
 # creating duplicates; probably not strictly necessary for output slicer
@@ -2087,9 +2086,9 @@ def concatenate_calls(expr: ArrayOrNames,
 call_site_to_dep_call_sites = \
 call_site_dep_collector.call_site_to_dep_call_sites

- unbatched_call_sites: Set[CallSiteLocation] = {
+ unbatched_call_sites: OrderedSet[CallSiteLocation] = OrderedSet(
 cs for cs in call_site_to_dep_call_sites.keys()
- if call_site_filter(cs) and fid in cs.call.function.tags}
+ if call_site_filter(cs) and fid in cs.call.function.tags)

 for cs in unbatched_call_sites:
 for ret in cs.call.function.returns.values():
@@ -2098,7 +2097,7 @@ def concatenate_calls(expr: ArrayOrNames,
 raise NotImplementedError(
 "Concatenation of nested calls is not yet supported.")

- call_site_batches: List[FrozenSet[CallSiteLocation]] = []
+ call_site_batches: List[FrozenOrderedSet[CallSiteLocation]] = []

 replacement_map: Dict[
 Tuple[NamedCallResult, Tuple[Call, ...]],
@@ -2107,9 +2106,9 @@ def concatenate_calls(expr: ArrayOrNames,
 used_call_results = collect_nodes_of_type(result, NamedCallResult)

 while unbatched_call_sites:
- ready_call_sites = frozenset({
+ ready_call_sites = FrozenOrderedSet(
 cs for cs in unbatched_call_sites
- if not call_site_to_dep_call_sites[cs] & unbatched_call_sites})
+ if not call_site_to_dep_call_sites[cs] & unbatched_call_sites)

 from mpi4py import MPI
 rank = MPI.COMM_WORLD.rank
@@ -2147,18 +2146,18 @@ def concatenate_calls(expr: ArrayOrNames,
 # # nnodes_other = get_num_nodes(cs.call.function.returns[name])
 # # print(f"{rank}: {name=}, {nnodes_template=}, {nnodes_other=}")

- similar_call_sites = frozenset({
+ similar_call_sites = FrozenOrderedSet(
 cs for cs in ready_call_sites
 if (
 (
- frozenset(cs.call.function.returns.keys())
- == frozenset(template_fn.returns.keys()))
+ FrozenOrderedSet(cs.call.function.returns.keys())
+ == FrozenOrderedSet(template_fn.returns.keys()))
 and all(
 similarity_comparer(
 cs.call.function.returns[name],
 template_fn.returns[name])
 for name in template_fn.returns)
- and cs.stack == template_call_site.stack)})
+ and cs.stack == template_call_site.stack))

 # if fid.identifier == "_make_fluid_state":
 # print(f"{rank}: {len(similar_call_sites)=}")
@@ -2166,18 +2165,19 @@ def concatenate_calls(expr: ArrayOrNames,
 if not similar_call_sites:
 raise ValueError("Failed to find similar call sites to concatenate.")

- def get_axis0_len(cs):
- first_out_name = next(iter(cs.call.function.returns.keys()))
- axis0_len = cs.call[first_out_name].shape[0]
- assert all(
- cs.call[name].shape[0] == axis0_len
- for name in cs.call.function.returns)
- return axis0_len
+ # def get_axis0_len(cs):
+ # first_out_name = next(iter(cs.call.function.returns.keys()))
+ # axis0_len = cs.call[first_out_name].shape[0]
+ # assert all(
+ # cs.call[name].shape[0] == axis0_len
+ # for name in cs.call.function.returns)
+ # return axis0_len

- batch_call_sites = sorted(similar_call_sites, key=get_axis0_len)
+ # batch_call_sites = FrozenOrderedSet(sorted(similar_call_sites, key=get_axis0_len))
+ batch_call_sites = similar_call_sites

 call_site_batches.append(batch_call_sites)
 unbatched_call_sites -= batch_call_sites

 # FIXME: this doesn't work; need to create/execute batches one at a time,
 # then repeat the steps above to collect the updated call sites after
@@ -2218,7 +2218,7 @@ def get_axis0_len(cs):
 input_concatenator=input_concatenator,
 output_slicer=output_slicer)

- stack, = {cs.stack for cs in call_sites}
+ stack, = FrozenOrderedSet(cs.stack for cs in call_sites)

 replacement_map.update({
 (old_expr, stack): new_expr
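
A note on the rationale for this series' move from set/frozenset to the
orderedsets containers used throughout the hunks above: iteration order of a
builtin set depends on element hashes, and str hashes are salted per
interpreter run (PYTHONHASHSEED), so any traversal, batching, or
code-generation order derived from a plain set can change from run to run and
from MPI rank to MPI rank. The sketch below is illustrative only and not part
of any patch; the element names are invented, and it assumes the orderedsets
package (which the diffs import) is installed.

    from functools import reduce

    from orderedsets import FrozenOrderedSet, OrderedSet

    # Builtin set: iteration order of str elements may differ between
    # interpreter runs due to hash randomization.
    print(list(frozenset({"recv_3", "send_0", "recv_1"})))

    # FrozenOrderedSet keeps set semantics but iterates in insertion
    # order, so any order derived from it is reproducible:
    fos = FrozenOrderedSet(["recv_3", "send_0", "recv_1"])
    assert list(fos) == ["recv_3", "send_0", "recv_1"]

    # The usual set algebra carries over, e.g. the reduce/union idiom
    # from _materialize_if_mpms above; union keeps first-seen order:
    merged = reduce(FrozenOrderedSet.union,
                    [FrozenOrderedSet([1, 2]), FrozenOrderedSet([2, 3])],
                    FrozenOrderedSet())
    assert list(merged) == [1, 2, 3]

    # On the mutable variant, remove() raises KeyError for a missing
    # element, while discard() (like `s -= {x}`) is a no-op; hence the
    # discard() in _ConcatabilityCollector.map_call above when taking a
    # call site off hold.
    on_hold = OrderedSet()
    on_hold.discard("cs")  # no-op; preserves the old `-= {expr}` behavior

With every set that feeds batch construction iterated in insertion order, the
call-site batches assembled in concatenate_calls, and hence the concatenated
DAG, should come out identical on every run and on every rank, which appears
to be the point of this final patch.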