Skip to content

Call BranchMorpher after dw deduplication #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
.. autoclass:: CachedWalkMapper
.. autoclass:: TopoSortMapper
.. autoclass:: CachedMapAndCopyMapper
.. autoclass:: PostMapEqualNodeReuser
.. autofunction:: copy_dict_of_named_arrays
.. autofunction:: get_dependencies
.. autofunction:: map_and_copy
Expand Down Expand Up @@ -1569,6 +1570,48 @@ def tag_user_nodes(
# }}}


# {{{ PostMapEqualNodeReuser

class PostMapEqualNodeReuser(CopyMapper):
"""
A mapper that reuses the same object instances for equal segments of
graphs.
Comment on lines +1577 to +1578
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
A mapper that reuses the same object instances for equal segments of
graphs.
A mapper that reuses the same object instances for segments of
graphs that become equal after processing by subclasses of this mapper,
where the equality test happens immediately after a new node is created.


.. note::

The operation performed here is equivalent to that of a
:class:`CopyMapper`, in that both return a single instance for equal
:class:`pytato.Array` nodes. However, they differ at the point where
two array expressions are compared. :class:`CopyMapper` compares array
expressions before the expressions are mapped i.e. repeatedly comparing
equal array expressions but unequal instances, and because of this it
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
equal array expressions but unequal instances, and because of this it
different instances representing equal array expressions, and, because of this, it

spends super-linear time in comparing array expressions. On the other
hand, :class:`PostMapEqualNodeReuser` has linear complexity in the
number of nodes in the number of array expressions as the larger mapped
expressions already contain same instances for the predecessors,
resulting in a cheaper equality comparison overall.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what you want to say instead here is that this avoids building potentially large repeated subexpexpressions in the first place, rather than having to (expensively) merge them after the fact.

"""
def __init__(self) -> None:
super().__init__()
self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {}

def cache_key(self, expr: CachedMapperT) -> Any:
return (id(expr), expr)

# type-ignore reason: incompatible with Mapper.rec
def rec(self, expr: MappedT) -> MappedT: # type: ignore[override]
rec_expr = super().rec(expr)
try:
# type-ignored because 'result_cache' maps to ArrayOrNames
return self.result_cache[rec_expr] # type: ignore[return-value]
except KeyError:
self.result_cache[rec_expr] = rec_expr
# type-ignored because of super-class' relaxed types
return rec_expr

# }}}


# {{{ deduplicate_data_wrappers

def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable:
Expand Down Expand Up @@ -1658,8 +1701,11 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
len(data_wrapper_cache),
data_wrappers_encountered - len(data_wrapper_cache))

return array_or_names
# many paths in the DAG might be semantically equivalent after DWs are
# deduplicated => morph them
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still: "morph" isn't good terminology.

return PostMapEqualNodeReuser()(array_or_names)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, wait. If it's used like this, I would argue that that's equivalent to just using a CopyMapper. I figured the data-wrapper-deduplicator should inherit from this to avoid building large identical subtrees in the first place.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test_post_map_equal_node_reuser_intestine to investigate whether this is equivalent to a CopyMapper, and at least in that example, it appears to be. It remains to figure out where the difference in execution time comes from.


# }}}


# vim: foldmethod=marker
2 changes: 1 addition & 1 deletion test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,7 @@ def test_zero_size_cl_array_dedup(ctx_factory):
x4 = pt.make_data_wrapper(x_cl2)

out = pt.make_dict_of_named_arrays({"out1": 2*x1,
"out2": 2*x2,
"out2": 3*x2,
"out3": x3 + x4
})

Expand Down
38 changes: 38 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from pyopencl.tools import ( # noqa
pytest_generate_tests_for_pyopencl as pytest_generate_tests)
from pytato.transform import CopyMapper, PostMapEqualNodeReuser, WalkMapper


def test_matmul_input_validation():
Expand Down Expand Up @@ -1115,6 +1116,43 @@ def test_rewrite_einsums_with_no_broadcasts():
assert pt.analysis.is_einsum_similar_to_subscript(new_expr.args[2], "ij,ik->ijk")


# {{{ test_post_map_equal_node_reuser

class _NodeInstanceCounter(WalkMapper):
def __init__(self):
self.ids = set()

def visit(self, expr):
self.ids.add(id(expr))
return True


def test_post_map_equal_node_reuser_intestine():
def construct_bad_intestine_graph(depth=10):
if depth == 0:
return pt.make_placeholder("x", shape=(10,), dtype=float)

return (
2 * construct_bad_intestine_graph(depth-1)
+ 3 * construct_bad_intestine_graph(depth-1))

def count_node_instances(graph):
nic = _NodeInstanceCounter()
nic(graph)
return len(nic.ids)

graph = construct_bad_intestine_graph()
assert count_node_instances(graph) == 4093

graph_cm = CopyMapper()(graph)
assert count_node_instances(graph_cm) == 31

graph_penr = PostMapEqualNodeReuser()(graph)
assert count_node_instances(graph_penr) == 31

# }}}


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down