Skip to content

New branch for TP axis tag testing #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 161 commits into
base: production
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
161 commits
Select commit Hold shift + click to select a range
f5c377e
Add tag to store array creation traceback
matthiasdiener Mar 23, 2022
4c32cb6
don't make it a unique tag
matthiasdiener Mar 23, 2022
56fbf4c
adds a common _get_default_tags
kaushikcfd Mar 23, 2022
8fee6d4
Change back to UniqueTag
matthiasdiener Mar 23, 2022
da3e962
Merge branch 'default_tags' into tag_created_at
matthiasdiener Mar 23, 2022
bdff59c
use _get_default_tags
matthiasdiener Mar 23, 2022
6d18144
store a tupleized StackSummary
matthiasdiener Mar 24, 2022
30dfc49
Merge branch 'main' into tag_created_at
matthiasdiener Mar 24, 2022
36e720b
Merge branch 'main' into tag_created_at
matthiasdiener Mar 24, 2022
2c3faeb
work around mypy
matthiasdiener Mar 24, 2022
739f3d3
more line fixes
matthiasdiener Mar 24, 2022
897e39b
Merge branch 'main' into tag_created_at
matthiasdiener Mar 25, 2022
3ec977f
Merge branch 'main' into tag_created_at
matthiasdiener Mar 25, 2022
fdb2906
Merge branch 'main' into tag_created_at
matthiasdiener Mar 28, 2022
4fd3d64
use a class for the traceback instead of tuples
matthiasdiener Mar 28, 2022
7702550
also test to_stacksummary
matthiasdiener Mar 28, 2022
7a86557
flake8
matthiasdiener Mar 28, 2022
31bcda8
Add remove_tags_of_type
inducer Mar 28, 2022
5c32222
test_array_dot_repr: Remove CreatedAt tags before comparing
inducer Mar 28, 2022
2d0358f
only add CreatedAt in debug mode
matthiasdiener Mar 28, 2022
1c275fd
restructure test_created_at
matthiasdiener Mar 28, 2022
02362bf
make _PytatoStackSummary a dataclass
matthiasdiener Mar 28, 2022
7b1f7b8
add __repr__
matthiasdiener Mar 29, 2022
07b2fa1
fix 2 tests
matthiasdiener Mar 29, 2022
437954c
illustrate test failure with construct_intestine_graph
matthiasdiener Mar 29, 2022
f05592e
shorten traceback printing
matthiasdiener Mar 29, 2022
d040996
use separate field for CreatedAt
matthiasdiener Mar 30, 2022
91d0929
Merge branch 'main' into tag_created_at
matthiasdiener Mar 30, 2022
9fdd602
fix tests
matthiasdiener Mar 30, 2022
e606a48
fix doctest
matthiasdiener Mar 30, 2022
235d9a7
make it a tag again
matthiasdiener Mar 30, 2022
fc9a873
Merge branch 'main' into tag_created_at
matthiasdiener Mar 31, 2022
d80066f
use tooltip instead of table row
matthiasdiener Apr 1, 2022
ef3339f
force openmpi usage
matthiasdiener Apr 1, 2022
1ff1a2b
check for existing CreatedAt and make it a UniqueTag again
matthiasdiener Apr 1, 2022
f509f41
Merge branch 'main' into tag_created_at
matthiasdiener Apr 2, 2022
73d6ab6
Merge branch 'main' into tag_created_at
matthiasdiener Apr 4, 2022
638e6e6
Merge branch 'main' into tag_created_at
matthiasdiener May 13, 2022
c57a4a1
flake8
matthiasdiener May 16, 2022
4ae31b1
add simple equality test
matthiasdiener May 16, 2022
f559a59
lint fixes
matthiasdiener May 16, 2022
0794bdb
add InfoTag class and filter tags based on it
matthiasdiener May 16, 2022
43c83ec
fix doc
matthiasdiener May 16, 2022
c09bbf3
another doc fix
matthiasdiener May 16, 2022
cd67d68
use IgnoredForEqualityTag
matthiasdiener May 17, 2022
71dd791
UNDO BEFORE MERGE: use external project branches
matthiasdiener May 17, 2022
b4f8b82
Revert "UNDO BEFORE MERGE: use external project branches"
matthiasdiener May 18, 2022
bfb22ba
Revert "use IgnoredForEqualityTag"
matthiasdiener May 18, 2022
99ff0cd
rename InfoTag -> IgnoredForEqualityTag
matthiasdiener May 18, 2022
a818694
more stringent tests
matthiasdiener May 18, 2022
8a0a773
undo unnecessary test changes
matthiasdiener May 18, 2022
91fe92f
Revert "Revert "use IgnoredForEqualityTag""
matthiasdiener May 19, 2022
1111b79
simplify condition
matthiasdiener May 19, 2022
4d5e5e2
Merge branch 'main' into tag_created_at
matthiasdiener May 19, 2022
ff2582f
Revert "simplify condition"
matthiasdiener May 19, 2022
26c3590
bump pytools version + a few spelling fixes
matthiasdiener May 21, 2022
423e3fb
remove duplicated self.axes in hash()
matthiasdiener May 21, 2022
a3e6120
Merge branch 'main' into tag_created_at
matthiasdiener Jun 1, 2022
87606ed
use Taggable{__eq__,__hash__}
matthiasdiener Jun 2, 2022
73c7f77
add another test
matthiasdiener Jun 4, 2022
b84b66e
add vis test
matthiasdiener Jun 4, 2022
26824fa
Merge branch 'main' into tag_created_at
matthiasdiener Jun 20, 2022
6c653bf
make _PytatoFrameSummary, _PytatoStackSummary undocumented
matthiasdiener Jun 20, 2022
02dd6f5
use Taggable.__hash__ for tags in Array.__hash__
matthiasdiener Jun 20, 2022
05b5383
Merge branch 'main' into tag_created_at
matthiasdiener Jun 23, 2022
38d475e
Merge branch 'main' into tag_created_at
matthiasdiener Aug 1, 2022
6fe4129
Merge branch 'main' into tag_created_at
matthiasdiener Aug 26, 2022
88887bf
Merge branch 'main' into tag_created_at
matthiasdiener Oct 4, 2022
15e78d6
Merge branch 'main' into tag_created_at
matthiasdiener Oct 20, 2022
3f084ba
Merge branch 'main' into tag_created_at
matthiasdiener Mar 24, 2023
7c93707
change dataclass to attrs
matthiasdiener Mar 28, 2023
dd9916b
flake8
matthiasdiener Mar 28, 2023
61c029c
Taggable.__eq__
matthiasdiener Mar 28, 2023
3cf1559
add Array.tagged()
matthiasdiener Mar 29, 2023
9d528c3
Merge branch 'main' into tag_created_at
matthiasdiener Apr 28, 2023
f9251d2
Merge branch 'main' into tag_created_at
matthiasdiener May 19, 2023
44d1c34
restrict to DEBUG_ENABLED
matthiasdiener May 19, 2023
a150c79
force DEBUG_ENABLED for test
matthiasdiener May 19, 2023
0769067
CHERRY-PICK: Preserve High-Level Info in the Pymbolic expressions
kaushikcfd Nov 17, 2021
8b3a13b
[CHERRY-PICK]: Call BranchMorpher after dw deduplication
kaushikcfd May 26, 2022
945a147
Merge branch 'main' into production-mrgup
MTCam Jun 27, 2023
818aec4
Merge branch 'main' into production-pilot
MTCam Jul 19, 2023
f4123b4
Update to inducer@main
MTCam Jul 25, 2023
afe340f
Merge branch 'main' into production-pilot
MTCam Jul 27, 2023
c0ea704
Merge branch 'main' into production-pilot
MTCam Jul 28, 2023
3d8ba63
Merge branch 'main' into production-pilot
MTCam Jul 31, 2023
1435a00
Merge remote-tracking branch 'origin/production-pilot' into productio…
MTCam Jul 31, 2023
7bbafe4
Merge branch 'main' into production-pilot
MTCam Aug 3, 2023
8665140
Merge branch 'main' into production-pilot
MTCam Aug 4, 2023
8e20d15
Define __attrs_post_init__ only if __debug__, for all Array classes
inducer Aug 4, 2023
9a4e01d
Merge remote-tracking branch 'inducer/attrs-post-init-only-if-debug' …
MTCam Aug 4, 2023
229f0f2
Merge branch 'main' into production-pilot
MTCam Aug 9, 2023
efcae65
First shot at implementing 'F' ordered array reshapes
a-alveyblanc Sep 11, 2023
cc8f07f
Resolve merge conflicts
a-alveyblanc Sep 11, 2023
35c6d1f
Remove restriction on reshape order
a-alveyblanc Sep 11, 2023
e8e11ce
Merge branch 'main' into production-pilot
MTCam Sep 18, 2023
a8ae5e2
Merge branch 'inducer:main' into implement-f-ordered-reshapes
a-alveyblanc Oct 8, 2023
351bb6f
Merge branch 'main' into production-pilot
MTCam Oct 10, 2023
ad8ff90
Merge branch 'main' into tag_created_at
matthiasdiener Oct 14, 2023
86233c6
work around mypy/attrs issue
matthiasdiener Oct 14, 2023
4b2d2cf
Merge branch 'attrs-mypy' into tag_created_at
matthiasdiener Oct 14, 2023
3b6fdad
fix for fields
matthiasdiener Oct 14, 2023
f6b8d98
Merge remote-tracking branch 'addison/implement-f-ordered-reshapes' i…
MTCam Oct 24, 2023
148527f
Merge branch 'main' into production-pilot
MTCam Oct 26, 2023
a64ea5a
Merge branch 'main' into production-pilot
MTCam Nov 1, 2023
8a390a5
Update comments a little
MTCam Nov 1, 2023
87efc4f
Merge branch 'main' into production-pilot
MTCam Nov 2, 2023
8dc06cf
Merge branch 'main' into production-pilot
MTCam Nov 8, 2023
870849a
attempt to fix tag issue
MTCam Nov 9, 2023
060f864
number_distributed_tags: non-set, non-sorted numbering
matthiasdiener Nov 9, 2023
65d0142
make the test a bit more difficult
matthiasdiener Nov 9, 2023
41a6998
Merge remote-tracking branch 'inducer/deterministic-mpi_tag-v2' into …
MTCam Nov 9, 2023
90954ea
Merge branch 'main' into production-pilot
MTCam Nov 14, 2023
9244899
Merge branch 'main' into tag_created_at
matthiasdiener Nov 14, 2023
3ebfcfd
undo mypy ignores
matthiasdiener Nov 14, 2023
eb1c052
rewrite to use a new field in Array, non_equality_tags
matthiasdiener Nov 14, 2023
a5cec50
misc fixes
matthiasdiener Nov 14, 2023
d9898c9
undo some unecessary changes
matthiasdiener Nov 15, 2023
c5c8920
more misc fixes
matthiasdiener Nov 15, 2023
4ec3cbf
copymapper, tests
matthiasdiener Nov 15, 2023
176595d
explicitly enable/disable traceback
matthiasdiener Nov 17, 2023
40557e9
add hash test
matthiasdiener Nov 17, 2023
5240495
undo more unnecessary changes
matthiasdiener Nov 17, 2023
6e047f4
Merge branch 'main' into tag_created_at
matthiasdiener Nov 17, 2023
48b1723
Merge branch 'main' into tag_created_at
matthiasdiener Nov 21, 2023
9338f0b
more lint fixes
matthiasdiener Nov 21, 2023
f5cb92f
run all examples, fix demo_distributed_node_duplication
matthiasdiener Nov 21, 2023
36166c6
enable CreatedAt for distributed nodes
matthiasdiener Nov 21, 2023
d110d0f
Merge branch 'main' into tag_created_at
matthiasdiener Nov 22, 2023
65f7317
Merge branch 'main' into tag_created_at
matthiasdiener Nov 26, 2023
c377937
Merge branch 'main' into tag_created_at
matthiasdiener Nov 28, 2023
ea93dc1
Merge branch 'main' into production-pilot
MTCam Nov 29, 2023
4c3b06a
undo MPI tag ordering
matthiasdiener Nov 29, 2023
db9f5c9
Merge branch 'main' into production-pilot
MTCam Jan 8, 2024
50bea3e
Merge branch 'main' into tag_created_at
matthiasdiener Jan 18, 2024
708114f
Merge branch 'production' into merge-addison-with-production
MTCam Feb 2, 2024
06503b1
get precise traceback of array creation
majosm Feb 2, 2024
ab87fbf
partialmethod doesn't introduce a stack frame
majosm Feb 2, 2024
d17db17
Merge branch 'main' into tag_created_at
matthiasdiener Feb 6, 2024
c30a320
Merge branch 'main' into production-pilot
MTCam Feb 6, 2024
d8df5f8
add support for make_distributed_send_ref_holder
matthiasdiener Feb 6, 2024
dee201f
Merge branch 'main' into production-pilot
MTCam Feb 7, 2024
dd7b288
Merge branch 'main' into tag_created_at
matthiasdiener Feb 7, 2024
49be05a
Merge branch 'main' into production-pilot
MTCam Feb 13, 2024
f0d52aa
Merge remote-tracking branch 'inducer/tag_created_at' into pytato-arr…
MTCam Feb 13, 2024
a3feae3
Merge branch 'main' into tag_created_at
matthiasdiener Feb 13, 2024
e1b9181
add to MPMSMaterializer
matthiasdiener Feb 16, 2024
74965e4
Merge remote-tracking branch 'inducer/tag_created_at' into pytato-arr…
MTCam Feb 16, 2024
be9dcdd
Spew array tracing to stdout.
MTCam Mar 2, 2024
c04d053
Merge branch 'main' into production-pilot
MTCam Mar 2, 2024
ba74f02
Merge branch 'main' into production-pilot
MTCam Mar 6, 2024
eade18d
Merge branch 'production' into array-tracing
MTCam Mar 7, 2024
3856ab7
Merge remote-tracking branch 'majosm/tag_created_at-precise-tb' into …
MTCam Mar 7, 2024
ad0aa4c
Get precise traceback of array creation (#480)
majosm Mar 7, 2024
655db9a
Merge branch 'main' into tag_created_at
matthiasdiener Mar 7, 2024
331dff1
Merge remote-tracking branch 'inducer/tag_created_at' into precise-ar…
MTCam Mar 7, 2024
6e879ad
Merge branch 'main' into production-pilot
MTCam Mar 24, 2024
0f5680b
Merge with inducer/main
MTCam Apr 4, 2024
ab5728e
Disable assert non_equality_tag
MTCam Apr 11, 2024
2b3eed8
Merge branch 'main' into production-pilot
MTCam Apr 15, 2024
33ff862
New branch for axis tag testing. Up to date this time :)
a-alveyblanc Apr 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions pytato/array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from traceback import FrameSummary, StackSummary

__copyright__ = """
Copyright (C) 2020 Andreas Kloeckner
Expand Down Expand Up @@ -593,25 +594,37 @@ def _unary_op(self, op: Any) -> Array:
non_equality_tags=_get_created_at_tag(),
var_to_reduction_descr=immutabledict())

__mul__ = partialmethod(_binary_op, operator.mul)
__rmul__ = partialmethod(_binary_op, operator.mul, reverse=True)
# NOTE: Initializing the expression to "prim.Product(expr1, expr2)" is
# essential as opposed to performing "expr1 * expr2". This is to account
# for pymbolic's implementation of the "*" operator which might not
# instantiate the node corresponding to the operation when one of
# the operands is the neutral element of the operation.
#
# For the same reason 'prim.(Sum|FloorDiv|Quotient)' is preferred over the
# python operators on the operands.

__add__ = partialmethod(_binary_op, operator.add)
__radd__ = partialmethod(_binary_op, operator.add, reverse=True)
__mul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)))
__rmul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)),
reverse=True)

__sub__ = partialmethod(_binary_op, operator.sub)
__rsub__ = partialmethod(_binary_op, operator.sub, reverse=True)
__add__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)))
__radd__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)),
reverse=True)

__floordiv__ = partialmethod(_binary_op, operator.floordiv)
__rfloordiv__ = partialmethod(_binary_op, operator.floordiv, reverse=True)
__sub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)))
__rsub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)),
reverse=True)

__truediv__ = partialmethod(_binary_op, operator.truediv,
__floordiv__ = partialmethod(_binary_op, prim.FloorDiv)
__rfloordiv__ = partialmethod(_binary_op, prim.FloorDiv, reverse=True)

__truediv__ = partialmethod(_binary_op, prim.Quotient,
get_result_type=_truediv_result_type)
__rtruediv__ = partialmethod(_binary_op, operator.truediv,
__rtruediv__ = partialmethod(_binary_op, prim.Quotient,
get_result_type=_truediv_result_type, reverse=True)

__pow__ = partialmethod(_binary_op, operator.pow)
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True)
__pow__ = partialmethod(_binary_op, prim.Power)
__rpow__ = partialmethod(_binary_op, prim.Power, reverse=True)

__neg__ = partialmethod(_unary_op, operator.neg)

Expand Down Expand Up @@ -1488,8 +1501,7 @@ class Reshape(IndexRemappingBase):

if __debug__:
def __attrs_post_init__(self) -> None:
# FIXME: Get rid of this restriction
assert self.order == "C"
# assert self.non_equality_tags
super().__attrs_post_init__()

@property
Expand Down Expand Up @@ -1981,8 +1993,7 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]],
"""
:param array: array to be reshaped
:param newshape: shape of the resulting array
:param order: ``"C"`` or ``"F"``. Layout order of the result array. Only
``"C"`` allowed for now.
:param order: ``"C"`` or ``"F"``. Layout order of the resulting array.

.. note::

Expand All @@ -2002,9 +2013,6 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]],
if not all(isinstance(axis_len, INT_CLASSES) for axis_len in array.shape):
raise ValueError("reshape of arrays with symbolic lengths not allowed")

if order != "C":
raise NotImplementedError("Reshapes to a 'F'-ordered arrays")

newshape_explicit = []

for new_axislen in newshape:
Expand Down
30 changes: 29 additions & 1 deletion pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,33 @@ def tag_user_nodes(
# }}}


# {{{ BranchMorpher

class BranchMorpher(CopyMapper):
"""
A mapper that replaces equal segments of graphs with identical objects.
"""
def __init__(self) -> None:
super().__init__()
self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {}

def cache_key(self, expr: CachedMapperT) -> Any:
return (id(expr), expr)

# type-ignore reason: incompatible with Mapper.rec
def rec(self, expr: MappedT) -> MappedT: # type: ignore[override]
rec_expr = super().rec(expr)
try:
# type-ignored because 'result_cache' maps to ArrayOrNames
return self.result_cache[rec_expr] # type: ignore[return-value]
except KeyError:
self.result_cache[rec_expr] = rec_expr
# type-ignored because of super-class' relaxed types
return rec_expr # type: ignore[no-any-return]

# }}}


# {{{ deduplicate_data_wrappers

def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable:
Expand Down Expand Up @@ -1893,8 +1920,9 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
len(data_wrapper_cache),
data_wrappers_encountered - len(data_wrapper_cache))

return array_or_names
return BranchMorpher()(array_or_names)

# }}}


# vim: foldmethod=marker
68 changes: 48 additions & 20 deletions pytato/transform/lower_to_index_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
AdvancedIndexInNoncontiguousAxes,
NormalizedSlice, ShapeType,
AbstractResultWithNamedArrays)
from pytato.scalar_expr import ScalarExpression, INT_CLASSES, IntegralT
from pytato.scalar_expr import ScalarExpression, INT_CLASSES
from pytato.diagnostic import CannotBeLoweredToIndexLambda
from pytato.tags import AssumeNonNegative
from pytato.transform import Mapper
Expand All @@ -51,30 +51,58 @@ def _get_reshaped_indices(expr: Reshape) -> Tuple[ScalarExpression, ...]:
assert expr.size == 1
return ()

if expr.order != "C":
raise NotImplementedError(expr.order)
if expr.order.upper() not in ["C", "F"]:
raise NotImplementedError("Order expected to be 'C' or 'F'",
f"(case insensitive) found {expr.order}")

newstrides: List[IntegralT] = [1] # reshaped array strides
for new_axis_len in reversed(expr.shape[1:]):
assert isinstance(new_axis_len, INT_CLASSES)
newstrides.insert(0, newstrides[0]*new_axis_len)
order = expr.order
oldshape = expr.array.shape
newshape = expr.shape

flattened_idx = sum(prim.Variable(f"_{i}")*stride
for i, stride in enumerate(newstrides))
# {{{ compute strides

oldstrides = [1]
oldstride_axes = (reversed(oldshape[1:]) if order == "C" else oldshape[:-1])

for ax_len in oldstride_axes:
assert isinstance(ax_len, INT_CLASSES)
oldstrides.append(oldstrides[-1]*ax_len)

newstrides = [1]
newstride_axes = (reversed(newshape[1:]) if order == "C" else newshape[:-1])

for ax_len in newstride_axes:
assert isinstance(ax_len, INT_CLASSES)
newstrides.append(newstrides[-1]*ax_len)

# }}}

# {{{ compute size tills

oldstrides: List[IntegralT] = [1] # input array strides
for axis_len in reversed(expr.array.shape[1:]):
assert isinstance(axis_len, INT_CLASSES)
oldstrides.insert(0, oldstrides[0]*axis_len)
oldsizetills = [oldshape[-1] if order == "C" else oldshape[0]]
oldsizetill_ax = (oldshape[:-1][::-1] if order == "C" else oldshape[:-1])
for ax_len in oldsizetill_ax:
oldsizetills.append(oldsizetills[-1]*ax_len)

# }}}

# {{{ if order is C, then computed info is backwards

if order == "C":
oldstrides = oldstrides[::-1]
newstrides = newstrides[::-1]
oldsizetills = oldsizetills[::-1]

# }}}

flattened_idx = sum(prim.Variable(f"_{i}")*stride
for i, stride in enumerate(newstrides))

assert isinstance(expr.array.shape[-1], INT_CLASSES)
oldsizetills = [expr.array.shape[-1]] # input array size till for axes idx
for old_axis_len in reversed(expr.array.shape[:-1]):
assert isinstance(old_axis_len, INT_CLASSES)
oldsizetills.insert(0, oldsizetills[0]*old_axis_len)
ret = tuple(
(flattened_idx % sizetill) // stride
for stride, sizetill in zip(oldstrides, oldsizetills))

return tuple(((flattened_idx % sizetill) // stride)
for stride, sizetill in zip(oldstrides, oldsizetills))
return ret


class ToIndexLambdaMixin:
Expand Down
86 changes: 41 additions & 45 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"""


from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict, FrozenSet,
from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict,
Mapping, Iterable, Any, TypeVar, cast)
from bidict import bidict
from pytato.scalar_expr import SCALAR_CLASSES
Expand All @@ -58,7 +58,7 @@
from pytato.diagnostic import UnknownIndexLambdaExpr

from pytools import UniqueNameGenerator
from pytools.tag import Tag
from pytools.tag import Tag, UniqueTag
import logging
logger = logging.getLogger(__name__)

Expand All @@ -70,6 +70,22 @@
GraphNodeT = TypeVar("GraphNodeT")


# {{{ IgnoredForPropagationTag

class AxisIgnoredForPropagationTag(UniqueTag):
"""
Used to influence equality constraints when determining which axes tags
are allowed to propagate along.

The intended use case for this is to prevent the axes of a matrix used to,
for example, differentiate a tensor of DOF data from picking up on the
unique tags attached to the axes of the tensor.
"""
pass

# }}}


# {{{ AxesTagsEquationCollector

class AxesTagsEquationCollector(Mapper):
Expand Down Expand Up @@ -167,6 +183,8 @@ def record_equations_from_axes_tags(self, ary: Array) -> None:
Records equations for *ary*\'s axis tags of type :attr:`tag_t`.
"""
for iaxis, axis in enumerate(ary.axes):
if axis.tags_of_type(AxisIgnoredForPropagationTag):
continue
lhs_var = self.get_var_for_axis(ary, iaxis)
for tag in axis.tags_of_type(self.tag_t):
rhs_var = self.get_var_for_tag(tag)
Expand Down Expand Up @@ -492,11 +510,12 @@ def map_einsum(self, expr: Einsum) -> None:
descr_to_var[EinsumElementwiseAxis(iaxis)] = self.get_var_for_axis(expr,
iaxis)

for access_descrs, arg in zip(expr.access_descriptors,
expr.args):
for access_descrs, arg in zip(expr.access_descriptors, expr.args):
for iarg_axis, descr in enumerate(access_descrs):
in_tag_var = self.get_var_for_axis(arg, iarg_axis)
if arg.axes[iarg_axis].tags_of_type(AxisIgnoredForPropagationTag):
continue

in_tag_var = self.get_var_for_axis(arg, iarg_axis)
if descr in descr_to_var:
self.record_equation(descr_to_var[descr], in_tag_var)
else:
Expand Down Expand Up @@ -556,38 +575,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array:
# }}}


def _get_propagation_graph_from_constraints(
equations: List[Tuple[str, str]]) -> Mapping[str, FrozenSet[str]]:
from immutabledict import immutabledict
propagation_graph: Dict[str, Set[str]] = {}
for lhs, rhs in equations:
assert lhs != rhs
propagation_graph.setdefault(lhs, set()).add(rhs)
propagation_graph.setdefault(rhs, set()).add(lhs)

return immutabledict({k: frozenset(v)
for k, v in propagation_graph.items()})


def get_reachable_nodes(undirected_graph: Mapping[GraphNodeT, Iterable[GraphNodeT]],
source_node: GraphNodeT) -> FrozenSet[GraphNodeT]:
"""
Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
reachable from *source_node*.
"""
nodes_visited: Set[GraphNodeT] = set()
nodes_to_visit = {source_node}
while nodes_to_visit:
current_node = nodes_to_visit.pop()
nodes_visited.add(current_node)

neighbors = undirected_graph[current_node]
nodes_to_visit.update({node
for node in neighbors
if node not in nodes_visited})

return frozenset(nodes_visited)

# {{{ AxisTagAttacher

class AxisTagAttacher(CopyMapper):
"""
Expand Down Expand Up @@ -659,6 +647,8 @@ def __call__(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override
assert isinstance(result, (Array, AbstractResultWithNamedArrays))
return result

# }}}


def unify_axes_tags(
expr: ArrayOrNames,
Expand Down Expand Up @@ -693,19 +683,25 @@ def unify_axes_tags(
# Defn. A Propagation graph is a graph where nodes denote variables and an
# edge between 2 nodes denotes an equality criterion.

propagation_graph = _get_propagation_graph_from_constraints(
equations_collector.equations)
from pytools.graph import (
get_propagation_graph_from_constraints,
get_reachable_nodes
)

known_tag_vars = frozenset(equations_collector.known_tag_to_var.values())
axis_to_solved_tags: Dict[Tuple[Array, int], Set[Tag]] = {}

propagation_graph = get_propagation_graph_from_constraints(
equations_collector.equations,
)

for tag, var in equations_collector.known_tag_to_var.items():
for reachable_var in (get_reachable_nodes(propagation_graph, var)
- known_tag_vars):
axis_to_solved_tags.setdefault(
equations_collector.axis_to_var.inverse[reachable_var],
set()
).add(tag)
reachable_nodes = get_reachable_nodes(propagation_graph, var)
for reachable_var in (reachable_nodes - known_tag_vars):
axis_to_solved_tags.setdefault(
equations_collector.axis_to_var.inverse[reachable_var],
set()
).add(tag)

return AxisTagAttacher(axis_to_solved_tags,
tag_corresponding_redn_descr=unify_redn_descrs,
Expand Down
2 changes: 1 addition & 1 deletion test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,7 +1610,7 @@ def test_zero_size_cl_array_dedup(ctx_factory):
x4 = pt.make_data_wrapper(x_cl2)

out = pt.make_dict_of_named_arrays({"out1": 2*x1,
"out2": 2*x2,
"out2": 3*x2,
"out3": x3 + x4
})

Expand Down
Loading