File tree 1 file changed +9
-3
lines changed
1 file changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper):
426
426
"""
427
427
428
428
def __init__ (self ) -> None :
429
+ from collections import defaultdict
429
430
super ().__init__ ()
430
- self .counts = {}
431
+ self .counts = defaultdict ( int )
431
432
432
433
def get_cache_key (self , expr : ArrayOrNames ) -> int :
433
434
return id (expr )
434
435
435
436
def post_visit (self , expr : Any ) -> None :
437
+ if type (expr ) not in counts :
438
+ self .counts [type (expr )] = 0
436
439
self .counts [type (expr )] += 1
437
440
438
441
439
- def get_num_node_types (outputs : Union [Array , DictOfNamedArrays ]) -> int :
440
- """Returns the number of nodes of each given type in DAG *outputs*."""
442
+ def get_num_node_types (outputs : Union [Array , DictOfNamedArrays ]) -> Dict [Type , int ]:
443
+ """
444
+ Returns a dictionary mapping node types to node count for that type
445
+ in DAG *outputs*.
446
+ """
441
447
442
448
from pytato .codegen import normalize_outputs
443
449
outputs = normalize_outputs (outputs )
You can’t perform that action at this time.
0 commit comments