|
31 | 31 | DictOfNamedArrays, NamedArray,
|
32 | 32 | IndexBase, IndexRemappingBase, InputArgumentBase,
|
33 | 33 | ShapeType)
|
| 34 | +from pytato.function import FunctionDefinition, Call |
34 | 35 | from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
|
35 | 36 | from pytato.loopy import LoopyCall
|
36 | 37 | from pymbolic.mapper.optimize import optimize_mapper
|
| 38 | +from pytools import memoize_method |
37 | 39 |
|
38 | 40 | if TYPE_CHECKING:
|
39 | 41 | from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
|
|
47 | 49 |
|
48 | 50 | .. autofunction:: get_num_nodes
|
49 | 51 |
|
| 52 | +.. autofunction:: get_num_call_sites |
| 53 | +
|
50 | 54 | .. autoclass:: DirectPredecessorsGetter
|
51 | 55 | """
|
52 | 56 |
|
@@ -388,3 +392,57 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
|
388 | 392 | return ncm.count
|
389 | 393 |
|
390 | 394 | # }}}
|
| 395 | + |
| 396 | + |
| 397 | +# {{{ CallSiteCountMapper |
| 398 | + |
| 399 | +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) |
| 400 | +class CallSiteCountMapper(CachedWalkMapper): |
| 401 | + """ |
| 402 | + Counts the number of nodes in a DAG. |
| 403 | +
|
| 404 | + .. attribute:: count |
| 405 | +
|
| 406 | + The number of nodes. |
| 407 | + """ |
| 408 | + |
| 409 | + def __init__(self) -> None: |
| 410 | + super().__init__() |
| 411 | + self.count = 0 |
| 412 | + |
| 413 | + # type-ignore-reason: dropped the extra `*args, **kwargs`. |
| 414 | + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] |
| 415 | + return id(expr) |
| 416 | + |
| 417 | + @memoize_method |
| 418 | + def map_function_definition(self, /, expr: FunctionDefinition, |
| 419 | + *args: Any, **kwargs: Any) -> None: |
| 420 | + if not self.visit(expr): |
| 421 | + return |
| 422 | + |
| 423 | + new_mapper = self.clone_for_callee() |
| 424 | + for subexpr in expr.returns.values(): |
| 425 | + new_mapper(subexpr, *args, **kwargs) |
| 426 | + |
| 427 | + self.count += new_mapper.count |
| 428 | + |
| 429 | + self.post_visit(expr, *args, **kwargs) |
| 430 | + |
| 431 | + # type-ignore-reason: dropped the extra `*args, **kwargs`. |
| 432 | + def post_visit(self, expr: Any) -> None: # type: ignore[override] |
| 433 | + if isinstance(expr, Call): |
| 434 | + self.count += 1 |
| 435 | + |
| 436 | + |
| 437 | +def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: |
| 438 | + """Returns the number of nodes in DAG *outputs*.""" |
| 439 | + |
| 440 | + from pytato.codegen import normalize_outputs |
| 441 | + outputs = normalize_outputs(outputs) |
| 442 | + |
| 443 | + cscm = CallSiteCountMapper() |
| 444 | + cscm(outputs) |
| 445 | + |
| 446 | + return cscm.count |
| 447 | + |
| 448 | +# }}} |
0 commit comments