|
26 | 26 | """
|
27 | 27 |
|
28 | 28 | from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet,
|
29 |
| - TYPE_CHECKING) |
| 29 | + Type, TYPE_CHECKING) |
30 | 30 | from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
|
31 | 31 | DictOfNamedArrays, NamedArray,
|
32 | 32 | IndexBase, IndexRemappingBase, InputArgumentBase,
|
@@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper):
|
426 | 426 | """
|
427 | 427 |
|
428 | 428 | def __init__(self) -> None:
|
| 429 | + from collections import defaultdict |
429 | 430 | super().__init__()
|
430 |
| - self.counts = {} |
| 431 | + self.counts = defaultdict(int) |
431 | 432 |
|
432 | 433 | def get_cache_key(self, expr: ArrayOrNames) -> int:
|
433 | 434 | return id(expr)
|
434 | 435 |
|
435 | 436 | def post_visit(self, expr: Any) -> None:
|
| 437 | + if type(expr) not in self.counts: |
| 438 | + self.counts[type(expr)] = 0 |
436 | 439 | self.counts[type(expr)] += 1
|
437 | 440 |
|
438 | 441 |
|
439 |
| -def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int: |
440 |
| - """Returns the number of nodes of each given type in DAG *outputs*.""" |
| 442 | +def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: |
| 443 | + """ |
| 444 | + Returns a dictionary mapping node types to node count for that type |
| 445 | + in DAG *outputs*. |
| 446 | + """ |
441 | 447 |
|
442 | 448 | from pytato.codegen import normalize_outputs
|
443 | 449 | outputs = normalize_outputs(outputs)
|
|
0 commit comments