Skip to content

Commit 633168e

Browse files
committed
add function to compute the max node depth of a DAG
1 parent 3805500 commit 633168e

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

pytato/analysis/__init__.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
IndexBase, IndexRemappingBase, InputArgumentBase,
3333
ShapeType)
3434
from pytato.function import FunctionDefinition, Call
35-
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
35+
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper, MappedT
3636
from pytato.loopy import LoopyCall
3737
from pymbolic.mapper.optimize import optimize_mapper
3838
from pytools import memoize_method
@@ -48,6 +48,7 @@
4848
.. autofunction:: is_einsum_similar_to_subscript
4949
5050
.. autofunction:: get_num_nodes
51+
.. autofunction:: get_max_node_depth
5152
5253
.. autofunction:: get_num_call_sites
5354
@@ -400,6 +401,55 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
400401
# }}}
401402

402403

404+
# {{{ NodeMaxDepthMapper
405+
406+
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
407+
class NodeMaxDepthMapper(CachedWalkMapper):
408+
"""
409+
Finds the maximum depth of a node in a DAG.
410+
411+
.. attribute:: max_depth
412+
413+
The depth of the deepest node.
414+
"""
415+
416+
def __init__(self) -> None:
417+
super().__init__()
418+
self.depth = 0
419+
self.max_depth = 0
420+
421+
# FIXME: Do I need this?
422+
# type-ignore-reason: dropped the extra `*args, **kwargs`.
423+
def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
424+
return id(expr)
425+
426+
def rec(self, expr: MappedT, *args: Any, **kwargs: Any) -> Any:
427+
"""Call the mapper method of *expr* and return the result."""
428+
self.depth += 1
429+
self.max_depth = max(self.max_depth, self.depth)
430+
431+
try:
432+
result = super().rec(expr, *args, **kwargs)
433+
finally:
434+
self.depth -= 1
435+
436+
return result
437+
438+
439+
def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int:
440+
"""Finds the maximum depth of a node in *outputs*."""
441+
442+
from pytato.codegen import normalize_outputs
443+
outputs = normalize_outputs(outputs)
444+
445+
nmdm = NodeMaxDepthMapper()
446+
nmdm(outputs)
447+
448+
return nmdm.max_depth
449+
450+
# }}}
451+
452+
403453
# {{{ CallSiteCountMapper
404454

405455
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)

0 commit comments

Comments
 (0)