Skip to content

Commit b9edcae

Browse files
committed
Fix some linting issues.
1 parent dfaaeed commit b9edcae

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

pytato/analysis/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"""
2727

2828
from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet,
29-
TYPE_CHECKING)
29+
Type, TYPE_CHECKING)
3030
from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
3131
DictOfNamedArrays, NamedArray,
3232
IndexBase, IndexRemappingBase, InputArgumentBase,
@@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper):
426426
"""
427427

428428
def __init__(self) -> None:
429+
from collections import defaultdict
429430
super().__init__()
430-
self.counts = {}
431+
self.counts = defaultdict(int)
431432

432433
def get_cache_key(self, expr: ArrayOrNames) -> int:
433434
return id(expr)
434435

435436
def post_visit(self, expr: Any) -> None:
437+
if type(expr) not in self.counts:
438+
self.counts[type(expr)] = 0
436439
self.counts[type(expr)] += 1
437440

438441

439-
def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int:
440-
"""Returns the number of nodes of each given type in DAG *outputs*."""
442+
def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]:
443+
"""
444+
Returns a dictionary mapping node types to node count for that type
445+
in DAG *outputs*.
446+
"""
441447

442448
from pytato.codegen import normalize_outputs
443449
outputs = normalize_outputs(outputs)

0 commit comments

Comments
 (0)