Skip to content

Commit 3a01309

Browse files
committed
Call PostMapEqualNodesReuser after dw deduplication
1 parent ac05808 commit 3a01309

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

pytato/transform/__init__.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
.. autoclass:: CachedWalkMapper
7474
.. autoclass:: TopoSortMapper
7575
.. autoclass:: CachedMapAndCopyMapper
76+
.. autoclass:: PostMapEqualNodeReuser
7677
.. autofunction:: copy_dict_of_named_arrays
7778
.. autofunction:: get_dependencies
7879
.. autofunction:: map_and_copy
@@ -1569,6 +1570,48 @@ def tag_user_nodes(
15691570
# }}}
15701571

15711572

1573+
# {{{ PostMapEqualNodeReuser
1574+
1575+
class PostMapEqualNodeReuser(CopyMapper):
1576+
"""
1577+
A mapper that reuses the same object instances for equal segments of
1578+
graphs.
1579+
1580+
.. note::
1581+
1582+
The operation performed here is equivalent to that of a
1583+
:class:`CopyMapper`, in that both return a single instance for equal
1584+
:class:`pytato.Array` nodes. However, they differ at the point where
1585+
two array expressions are compared. :class:`CopyMapper` compares array
1586+
expressions before the expressions are mapped i.e. repeatedly comparing
1587+
equal array expressions but unequal instances, and because of this it
1588+
spends super-linear time in comparing array expressions. On the other
1589+
hand, :class:`PostMapEqualNodeReuser` has linear complexity in the
1590+
number of nodes in the number of array expressions as the larger mapped
1591+
expressions already contain same instances for the predecessors,
1592+
resulting in a cheaper equality comparison overall.
1593+
"""
1594+
def __init__(self) -> None:
1595+
super().__init__()
1596+
self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {}
1597+
1598+
def cache_key(self, expr: CachedMapperT) -> Any:
1599+
return (id(expr), expr)
1600+
1601+
# type-ignore reason: incompatible with Mapper.rec
1602+
def rec(self, expr: MappedT) -> MappedT: # type: ignore[override]
1603+
rec_expr = super().rec(expr)
1604+
try:
1605+
# type-ignored because 'result_cache' maps to ArrayOrNames
1606+
return self.result_cache[rec_expr] # type: ignore[return-value]
1607+
except KeyError:
1608+
self.result_cache[rec_expr] = rec_expr
1609+
# type-ignored because of super-class' relaxed types
1610+
return rec_expr
1611+
1612+
# }}}
1613+
1614+
15721615
# {{{ deduplicate_data_wrappers
15731616

15741617
def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable:
@@ -1658,8 +1701,11 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
16581701
len(data_wrapper_cache),
16591702
data_wrappers_encountered - len(data_wrapper_cache))
16601703

1661-
return array_or_names
1704+
# many paths in the DAG might be semantically equivalent after DWs are
1705+
# deduplicated => morph them
1706+
return PostMapEqualNodeReuser()(array_or_names)
16621707

16631708
# }}}
16641709

1710+
16651711
# vim: foldmethod=marker

test/test_codegen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1556,7 +1556,7 @@ def test_zero_size_cl_array_dedup(ctx_factory):
15561556
x4 = pt.make_data_wrapper(x_cl2)
15571557

15581558
out = pt.make_dict_of_named_arrays({"out1": 2*x1,
1559-
"out2": 2*x2,
1559+
"out2": 3*x2,
15601560
"out3": x3 + x4
15611561
})
15621562

0 commit comments

Comments
 (0)