Skip to content

Commit dfaaeed

Browse files
committed
Add node type counter
1 parent 1990ffa commit dfaaeed

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

pytato/analysis/__init__.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,43 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
413413
# }}}
414414

415415

416+
# {{{ NodeTypeCountMapper
417+
418+
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
419+
class NodeTypeCountMapper(CachedWalkMapper):
420+
"""
421+
Counts the number of nodes of a given type in a DAG.
422+
423+
.. attribute:: counts
424+
425+
Dictionary mapping node types to number of nodes of that type.
426+
"""
427+
428+
def __init__(self) -> None:
429+
super().__init__()
430+
self.counts = {}
431+
432+
def get_cache_key(self, expr: ArrayOrNames) -> int:
433+
return id(expr)
434+
435+
def post_visit(self, expr: Any) -> None:
436+
self.counts[type(expr)] += 1
437+
438+
439+
def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int:
440+
"""Returns the number of nodes of each given type in DAG *outputs*."""
441+
442+
from pytato.codegen import normalize_outputs
443+
outputs = normalize_outputs(outputs)
444+
445+
ncm = NodeTypeCountMapper()
446+
ncm(outputs)
447+
448+
return ncm.counts
449+
450+
# }}}
451+
452+
416453
# {{{ CallSiteCountMapper
417454

418455
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)

0 commit comments

Comments
 (0)