Skip to content

add materialization counter #333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
"""

from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet,
TYPE_CHECKING)
TYPE_CHECKING, Iterable)
from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
ShapeType)
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper, CombineMapper
from pytato.loopy import LoopyCall
from pytools.tag import Tag
from pytato.tags import ImplStored

if TYPE_CHECKING:
from pytato.distributed import DistributedRecv, DistributedSendRefHolder
Expand All @@ -47,6 +49,11 @@
.. autofunction:: get_num_nodes

.. autoclass:: DirectPredecessorsGetter

.. autoclass:: TagCountMapper
.. autofunction:: get_num_tags_of_type

.. autofunction:: get_num_materialized
"""


Expand Down Expand Up @@ -371,12 +378,78 @@ def post_visit(self, expr: Any) -> None:
def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
    """Returns the number of nodes in DAG *outputs*.

    *outputs* is first passed through
    :func:`pytato.codegen.normalize_outputs`; a fresh
    :class:`NodeCountMapper` is then run over the normalized result and
    its ``count`` attribute is returned.
    """
    # NOTE(review): function-scope import, presumably to avoid an import
    # cycle with pytato.codegen -- confirm before hoisting.
    from pytato.codegen import normalize_outputs

    normalized = normalize_outputs(outputs)

    counter = NodeCountMapper()
    counter(normalized)
    return counter.count

# }}}


# {{{ TagCountMapper

class TagCountMapper(CombineMapper[int]):
    """
    Counts the nodes in a DAG that carry every tag in *tags*.

    *tags* may be a single :class:`pytools.tag.Tag` or an iterable of tags;
    it is stored internally as a :class:`frozenset`. A node is counted when
    it is an :class:`~pytato.array.Array` whose ``tags`` is a superset of
    that set.
    """

    def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None:
        super().__init__()

        # Normalize the accepted spellings of *tags* to one frozenset.
        if isinstance(tags, Tag):
            wanted = frozenset((tags,))
        elif isinstance(tags, frozenset):
            wanted = tags
        else:
            wanted = frozenset(tags)

        self._tags = wanted

    def combine(self, *args: int) -> int:
        """Adds up the counts obtained from a node's children."""
        return sum(args)

    # type-ignore reason: incompatible return type with super class
    def rec(self, expr: ArrayOrNames) -> int:  # type: ignore
        """Returns the count contributed by *expr* on this visit.

        The first visit yields the count for *expr* itself (0 or 1) plus
        whatever the inherited traversal returns for it. Any later visit
        yields the cached value, which this method sets to 0.
        """
        if expr in self.cache:
            return self.cache[expr]

        own_count = (
            1 if isinstance(expr, Array) and self._tags <= expr.tags else 0)
        result = own_count + super().rec(expr)

        # Deliberately cache 0 instead of *result*: a node reached again
        # through another path then adds nothing to the sum, so nodes shared
        # within the DAG are only counted once.
        self.cache[expr] = 0
        return result


def get_num_tags_of_type(
        outputs: Union[Array, DictOfNamedArrays],
        tags: Union[Tag, Iterable[Tag]]) -> int:
    """Returns the number of nodes in DAG *outputs* that are tagged with
    all the tags in *tags*.

    Thin wrapper that builds a :class:`TagCountMapper` for *tags* and
    applies it to *outputs*.
    """
    return TagCountMapper(tags)(outputs)

# }}}


def get_num_materialized(outputs: Union[Array, DictOfNamedArrays]) \
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_num_materialized(outputs: Union[Array, DictOfNamedArrays]) \
def get_num_materialized_predecessors(outputs: Union[Array, DictOfNamedArrays]) \

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My main issue with this approach is that it has at least quadratic complexity. A possible improvement would be a CombineMapper that simply counts the materialized predecessors for each node. I.e. combine is a sum, and each node returns 1 if it's materialized and uses inherited behavior (reduce ("combine") over predecessors).

-> Dict[ArrayOrNames, int]:
"""Returns the number of materialized nodes each node in *outputs* depends on."""
from pytato.transform import rec_get_all_user_nodes
users = rec_get_all_user_nodes(outputs)

def is_materialized(expr: ArrayOrNames) -> bool:
if (isinstance(expr, Array) and
any(isinstance(tag, ImplStored) for tag in expr.tags)):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
any(isinstance(tag, ImplStored) for tag in expr.tags)):
expr.tags_of_type(ImplStored)):

return True
else:
return False

res: Dict[ArrayOrNames, int] = {}

for node in users.keys():
if is_materialized(node):
for user in users[node]:
res.setdefault(user, 0)
res[user] += 1

return res
35 changes: 29 additions & 6 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
.. autofunction:: reverse_graph
.. autofunction:: tag_user_nodes
.. autofunction:: rec_get_user_nodes
.. autofunction:: rec_get_all_user_nodes

.. autofunction:: deduplicate_data_wrappers

Expand Down Expand Up @@ -1030,8 +1031,8 @@ def _materialize_if_mpms(expr: Array,
) -> MPMSMaterializerAccumulator:
"""
Returns an instance of :class:`MPMSMaterializerAccumulator`, that
materializes *expr* if it has more than 1 successors and more than 1
materialized predecessors.
materializes *expr* if it has more than 1 successor and more than 1
materialized predecessor.
"""
from functools import reduce

Expand Down Expand Up @@ -1250,8 +1251,8 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
.. note::

- MPMS materialization strategy is a greedy materialization algorithm in
which any node with more than 1 materialized predecessors and more than
1 successors is materialized.
which any node with more than 1 materialized predecessor and more than
1 successor is materialized.
- Materializing here corresponds to tagging a node with
:class:`~pytato.tags.ImplStored`.
- Does not attempt to materialize sub-expressions in
Expand Down Expand Up @@ -1292,13 +1293,21 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays:
====== ======== =======

"""
from pytato.analysis import get_nusers
from pytato.analysis import get_nusers, get_num_nodes, get_num_tags_of_type
materializer = MPMSMaterializer(get_nusers(expr))
new_data = {}
for name, ary in expr.items():
new_data[name] = materializer(ary.expr).expr

return DictOfNamedArrays(new_data)
res = DictOfNamedArrays(new_data)

from pytato import DEBUG_ENABLED
if DEBUG_ENABLED:
logger.info("materialize_with_mpms: materialized "
f"{get_num_tags_of_type(res, ImplStored())} out of "
f"{get_num_nodes(res)} nodes")

return res

# }}}

Expand Down Expand Up @@ -1501,6 +1510,20 @@ def rec_get_user_nodes(expr: ArrayOrNames,
return _recursively_get_all_users(users, node)


def rec_get_all_user_nodes(expr: ArrayOrNames) \
        -> Dict[ArrayOrNames, FrozenSet[ArrayOrNames]]:
    """
    Returns all direct and indirect users of all nodes in *expr*.

    The result has one entry per key of the mapping returned by
    :func:`get_users` for *expr*; each value is what
    ``_recursively_get_all_users`` yields for that key.
    """
    direct_users = get_users(expr)

    return {node: _recursively_get_all_users(direct_users, node)
            for node in direct_users}


def tag_user_nodes(
graph: Mapping[ArrayOrNames, Set[ArrayOrNames]],
tag: Any,
Expand Down
58 changes: 56 additions & 2 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,7 @@ def test_nodecountmapper():
axis_len=axis_len, use_numpy=False)
dag = make_random_dag(rdagc)

# Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays.
assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag))
assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag))


def test_rec_get_user_nodes():
Expand Down Expand Up @@ -884,6 +883,41 @@ def test_adv_indexing_into_zero_long_axes():
# }}}


def test_tagcountmapper():
    """Checks get_num_tags_of_type on a random DAG and on a DAG that
    reuses the same operand many times."""
    from testlib import RandomDAGContext, make_random_dag
    from pytato.analysis import get_num_tags_of_type, get_num_nodes
    from pytools.tag import Tag

    class NonExistentTag(Tag):
        pass

    class ExistentTag(Tag):
        pass

    rng = np.random.default_rng(seed=199)
    rdagc_pt = RandomDAGContext(rng, axis_len=3, use_numpy=False)

    # Only the output node of the random DAG receives ExistentTag.
    tagged_out = make_random_dag(rdagc_pt).tagged(ExistentTag())
    named_dag = pt.make_dict_of_named_arrays({"out": tagged_out})

    # get_num_nodes() returns an extra DictOfNamedArrays node
    n_all = get_num_tags_of_type(named_dag, frozenset())
    assert n_all == get_num_nodes(named_dag)-1

    expected_counts = [
        (NonExistentTag(), 0),
        (frozenset((ExistentTag(),)), 1),
        (frozenset((ExistentTag(), NonExistentTag())), 0),
    ]
    for tags, expected in expected_counts:
        assert get_num_tags_of_type(named_dag, tags) == expected

    # The same operand appears repeatedly in this expression; the empty
    # tag set must still yield the same figure as get_num_nodes().
    a = pt.make_data_wrapper(np.arange(27))
    repeated_sum = a+a+a+a+a+a+a+a

    assert (get_num_tags_of_type(repeated_sum, frozenset())
            == get_num_nodes(repeated_sum))


def test_expand_dims_input_validate():
a = pt.make_placeholder("x", (10, 4), dtype="float64")

Expand All @@ -901,6 +935,26 @@ def test_expand_dims_input_validate():
pt.expand_dims(a, -4)


def test_materialization_counter():
    """Checks the largest per-node count reported by get_num_materialized
    for a fixed-seed random DAG after MPMS materialization."""
    from pytato.analysis import get_num_materialized
    from testlib import RandomDAGContext, make_random_dag

    rng = np.random.default_rng(seed=1999)
    rdagc_pt = RandomDAGContext(rng, axis_len=4, use_numpy=False)

    named = pt.make_dict_of_named_arrays({"out": make_random_dag(rdagc_pt)})
    materialized = pt.transform.materialize_with_mpms(named)

    counts = get_num_materialized(materialized)

    # Expected value is tied to the seed and axis length used above.
    assert max(counts.values()) == 6


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down