Skip to content

Commit e5cfb89

Browse files
kaushikcfdinducer
authored andcommitted
Add pt.analysis.get_num_call_sites
1 parent 19fd395 commit e5cfb89

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

pytato/analysis/__init__.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
DictOfNamedArrays, NamedArray,
3232
IndexBase, IndexRemappingBase, InputArgumentBase,
3333
ShapeType)
34+
from pytato.function import FunctionDefinition, Call
3435
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
3536
from pytato.loopy import LoopyCall
3637
from pymbolic.mapper.optimize import optimize_mapper
38+
from pytools import memoize_method
3739

3840
if TYPE_CHECKING:
3941
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
@@ -47,6 +49,8 @@
4749
4850
.. autofunction:: get_num_nodes
4951
52+
.. autofunction:: get_num_call_sites
53+
5054
.. autoclass:: DirectPredecessorsGetter
5155
"""
5256

@@ -388,3 +392,57 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
388392
return ncm.count
389393

390394
# }}}
395+
396+
397+
# {{{ CallSiteCountMapper
398+
399+
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
400+
class CallSiteCountMapper(CachedWalkMapper):
401+
"""
402+
Counts the number of nodes in a DAG.
403+
404+
.. attribute:: count
405+
406+
The number of nodes.
407+
"""
408+
409+
def __init__(self) -> None:
410+
super().__init__()
411+
self.count = 0
412+
413+
# type-ignore-reason: dropped the extra `*args, **kwargs`.
414+
def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
415+
return id(expr)
416+
417+
@memoize_method
418+
def map_function_definition(self, /, expr: FunctionDefinition,
419+
*args: Any, **kwargs: Any) -> None:
420+
if not self.visit(expr):
421+
return
422+
423+
new_mapper = self.clone_for_callee()
424+
for subexpr in expr.returns.values():
425+
new_mapper(subexpr, *args, **kwargs)
426+
427+
self.count += new_mapper.count
428+
429+
self.post_visit(expr, *args, **kwargs)
430+
431+
# type-ignore-reason: dropped the extra `*args, **kwargs`.
432+
def post_visit(self, expr: Any) -> None: # type: ignore[override]
433+
if isinstance(expr, Call):
434+
self.count += 1
435+
436+
437+
def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int:
438+
"""Returns the number of nodes in DAG *outputs*."""
439+
440+
from pytato.codegen import normalize_outputs
441+
outputs = normalize_outputs(outputs)
442+
443+
cscm = CallSiteCountMapper()
444+
cscm(outputs)
445+
446+
return cscm.count
447+
448+
# }}}

0 commit comments

Comments
 (0)