|
32 | 32 | IndexBase, IndexRemappingBase, InputArgumentBase,
|
33 | 33 | ShapeType)
|
34 | 34 | from pytato.function import FunctionDefinition, Call
|
35 |
| -from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper |
| 35 | +from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper, MappedT |
36 | 36 | from pytato.loopy import LoopyCall
|
37 | 37 | from pymbolic.mapper.optimize import optimize_mapper
|
38 | 38 | from pytools import memoize_method
|
|
48 | 48 | .. autofunction:: is_einsum_similar_to_subscript
|
49 | 49 |
|
50 | 50 | .. autofunction:: get_num_nodes
|
| 51 | +.. autofunction:: get_max_node_depth |
51 | 52 |
|
52 | 53 | .. autofunction:: get_num_call_sites
|
53 | 54 |
|
@@ -400,6 +401,55 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
|
400 | 401 | # }}}
|
401 | 402 |
|
402 | 403 |
|
| 404 | +# {{{ NodeMaxDepthMapper |
| 405 | + |
| 406 | +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) |
| 407 | +class NodeMaxDepthMapper(CachedWalkMapper): |
| 408 | + """ |
| 409 | + Finds the maximum depth of a node in a DAG. |
| 410 | +
|
| 411 | + .. attribute:: max_depth |
| 412 | +
|
| 413 | + The depth of the deepest node. |
| 414 | + """ |
| 415 | + |
| 416 | + def __init__(self) -> None: |
| 417 | + super().__init__() |
| 418 | + self.depth = 0 |
| 419 | + self.max_depth = 0 |
| 420 | + |
| 421 | + # FIXME: Do I need this? |
| 422 | + # type-ignore-reason: dropped the extra `*args, **kwargs`. |
| 423 | + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] |
| 424 | + return id(expr) |
| 425 | + |
| 426 | + def rec(self, expr: MappedT, *args: Any, **kwargs: Any) -> Any: |
| 427 | + """Call the mapper method of *expr* and return the result.""" |
| 428 | + self.depth += 1 |
| 429 | + self.max_depth = max(self.max_depth, self.depth) |
| 430 | + |
| 431 | + try: |
| 432 | + result = super().rec(expr, *args, **kwargs) |
| 433 | + finally: |
| 434 | + self.depth -= 1 |
| 435 | + |
| 436 | + return result |
| 437 | + |
| 438 | + |
| 439 | +def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int: |
| 440 | + """Finds the maximum depth of a node in *outputs*.""" |
| 441 | + |
| 442 | + from pytato.codegen import normalize_outputs |
| 443 | + outputs = normalize_outputs(outputs) |
| 444 | + |
| 445 | + nmdm = NodeMaxDepthMapper() |
| 446 | + nmdm(outputs) |
| 447 | + |
| 448 | + return nmdm.max_depth |
| 449 | + |
| 450 | +# }}} |
| 451 | + |
| 452 | + |
403 | 453 | # {{{ CallSiteCountMapper
|
404 | 454 |
|
405 | 455 | @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
|
|
0 commit comments