|
73 | 73 | .. autoclass:: CachedWalkMapper
|
74 | 74 | .. autoclass:: TopoSortMapper
|
75 | 75 | .. autoclass:: CachedMapAndCopyMapper
|
| 76 | +.. autoclass:: PostMapEqualNodeReuser |
76 | 77 | .. autofunction:: copy_dict_of_named_arrays
|
77 | 78 | .. autofunction:: get_dependencies
|
78 | 79 | .. autofunction:: map_and_copy
|
@@ -1569,6 +1570,48 @@ def tag_user_nodes(
|
1569 | 1570 | # }}}
|
1570 | 1571 |
|
1571 | 1572 |
|
| 1573 | +# {{{ PostMapEqualNodeReuser |
| 1574 | + |
class PostMapEqualNodeReuser(CopyMapper):
    """
    A :class:`CopyMapper` that makes equal segments of a graph share a single
    object instance in the mapped result.

    .. note::

        Like :class:`CopyMapper`, this mapper returns a single instance for
        equal :class:`pytato.Array` nodes; the two differ in *when* array
        expressions are compared. :class:`CopyMapper` compares expressions
        before they are mapped, i.e. it repeatedly compares expressions that
        are equal but are distinct instances, and therefore spends
        super-linear time on comparisons. :class:`PostMapEqualNodeReuser`
        compares expressions after they are mapped: by then the predecessors
        of a larger expression are already the same instances, which makes
        the equality comparisons cheaper overall and keeps the cost linear in
        the number of nodes of the array expressions.
    """

    def __init__(self) -> None:
        super().__init__()
        # Maps a mapped expression to the first stored instance equal to it.
        self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {}

    def cache_key(self, expr: CachedMapperT) -> Any:
        # The key pairs the object's identity with the object itself.
        # NOTE(review): presumably holding *expr* in the key keeps it alive so
        # its id() cannot be recycled -- confirm.
        return (id(expr), expr)

    # type-ignore reason: incompatible with Mapper.rec
    def rec(self, expr: MappedT) -> MappedT:  # type: ignore[override]
        mapped = super().rec(expr)
        # Return the instance already stored for an equal mapped expression;
        # if there is none, store this one and return it.
        # type-ignored because 'result_cache' maps to ArrayOrNames
        return self.result_cache.setdefault(  # type: ignore[return-value]
            mapped, mapped)
| 1611 | + |
| 1612 | +# }}} |
| 1613 | + |
| 1614 | + |
1572 | 1615 | # {{{ deduplicate_data_wrappers
|
1573 | 1616 |
|
1574 | 1617 | def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable:
|
@@ -1658,8 +1701,11 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
|
1658 | 1701 | len(data_wrapper_cache),
|
1659 | 1702 | data_wrappers_encountered - len(data_wrapper_cache))
|
1660 | 1703 |
|
1661 |
| - return array_or_names |
| 1704 | + # Many paths in the DAG may become equal once the data wrappers are |
| 1705 | + # deduplicated, so make equal nodes share a single instance. |
| 1706 | + return PostMapEqualNodeReuser()(array_or_names) |
1662 | 1707 |
|
1663 | 1708 | # }}}
|
1664 | 1709 |
|
| 1710 | + |
1665 | 1711 | # vim: foldmethod=marker
|
0 commit comments