Skip to content

Commit e832946

Browse files
committed
add function to compute the max node depth of a DAG
1 parent 3805500 commit e832946

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

pytato/analysis/__init__.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
.. autofunction:: is_einsum_similar_to_subscript
4949
5050
.. autofunction:: get_num_nodes
51+
.. autofunction:: get_max_node_depth
5152
5253
.. autofunction:: get_num_call_sites
5354
@@ -400,6 +401,53 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
400401
# }}}
401402

402403

404+
# {{{ NodeMaxDepthMapper
405+
406+
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
407+
class NodeMaxDepthMapper(CachedWalkMapper):
408+
"""
409+
Finds the maximum depth of a node in a DAG.
410+
411+
.. attribute:: max_depth
412+
413+
The depth of the deepest node.
414+
"""
415+
416+
def __init__(self) -> None:
417+
super().__init__()
418+
self.depth = 0
419+
self.max_depth = 0
420+
421+
# FIXME: Do I need this?
422+
# type-ignore-reason: dropped the extra `*args, **kwargs`.
423+
def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
424+
return id(expr)
425+
426+
def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None:
427+
"""Call the mapper method of *expr* and return the result."""
428+
self.depth += 1
429+
self.max_depth = max(self.max_depth, self.depth)
430+
431+
try:
432+
super().rec(expr, *args, **kwargs)
433+
finally:
434+
self.depth -= 1
435+
436+
437+
def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int:
438+
"""Finds the maximum depth of a node in *outputs*."""
439+
440+
from pytato.codegen import normalize_outputs
441+
outputs = normalize_outputs(outputs)
442+
443+
nmdm = NodeMaxDepthMapper()
444+
nmdm(outputs)
445+
446+
return nmdm.max_depth
447+
448+
# }}}
449+
450+
403451
# {{{ CallSiteCountMapper
404452

405453
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)

0 commit comments

Comments
 (0)