diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 6db21e863..4cdde2c82 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -66,6 +66,7 @@ .. autofunction:: is_einsum_similar_to_subscript .. autofunction:: get_num_nodes +.. autofunction:: get_max_node_depth .. autofunction:: get_node_type_counts @@ -545,6 +546,57 @@ def get_node_multiplicities( # }}} +# {{{ NodeMaxDepthMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class NodeMaxDepthMapper(CachedWalkMapper): + """ + Finds the maximum depth of a node in a DAG. + + .. attribute:: max_depth + + The depth of the deepest node. + """ + + def __init__(self) -> None: + super().__init__() + # Want the first rec() call to increment to 0, so start at -1 + self.depth = -1 + self.max_depth = -1 + + # FIXME: Do I need this? + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + return id(expr) + + def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None: + """Call the mapper method of *expr* and return the result.""" + if isinstance(expr, DictOfNamedArrays): + super().rec(expr, *args, **kwargs) + else: + self.depth += 1 + self.max_depth = max(self.max_depth, self.depth) + + try: + super().rec(expr, *args, **kwargs) + finally: + self.depth -= 1 + + +def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int: + """Finds the maximum depth of a node in *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + nmdm = NodeMaxDepthMapper() + nmdm(outputs) + + return nmdm.max_depth + +# }}} + + # {{{ CallSiteCountMapper @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) diff --git a/test/test_pytato.py b/test/test_pytato.py index 3d16e28d9..da75bcc89 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -765,6 +765,19 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_nodemaxdepthmapper(): + from pytato.analysis import get_max_node_depth + + x = pt.make_placeholder("x", shape=(10, 4), dtype=np.float64) + + assert get_max_node_depth(x) == 0 + + for _ in range(10): + x = x + 1 + + assert get_max_node_depth(x) == 10 + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4))