Skip to content

Commit 221de48

Browse files
committed
--amend
1 parent dfaaeed commit 221de48

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

pytato/analysis/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper):
426426
"""
427427

428428
def __init__(self) -> None:
429+
from collections import defaultdict
429430
super().__init__()
430-
self.counts = {}
431+
self.counts = defaultdict(int)
431432

432433
def get_cache_key(self, expr: ArrayOrNames) -> int:
433434
return id(expr)
434435

435436
def post_visit(self, expr: Any) -> None:
437+
if type(expr) not in counts:
438+
self.counts[type(expr)] = 0
436439
self.counts[type(expr)] += 1
437440

438441

439-
def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int:
440-
"""Returns the number of nodes of each given type in DAG *outputs*."""
442+
def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]:
443+
"""
444+
Returns a dictionary mapping node types to node count for that type
445+
in DAG *outputs*.
446+
"""
441447

442448
from pytato.codegen import normalize_outputs
443449
outputs = normalize_outputs(outputs)

0 commit comments

Comments
 (0)