File tree 1 file changed +37
-0
lines changed
1 file changed +37
-0
lines changed Original file line number Diff line number Diff line change @@ -413,6 +413,43 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
413
413
# }}}
414
414
415
415
416
+ # {{{ NodeTypeCountMapper
417
+
418
+ @optimize_mapper (drop_args = True , drop_kwargs = True , inline_get_cache_key = True )
419
+ class NodeTypeCountMapper (CachedWalkMapper ):
420
+ """
421
+ Counts the number of nodes of a given type in a DAG.
422
+
423
+ .. attribute:: counts
424
+
425
+ Dictionary mapping node types to number of nodes of that type.
426
+ """
427
+
428
+ def __init__ (self ) -> None :
429
+ super ().__init__ ()
430
+ self .counts = {}
431
+
432
+ def get_cache_key (self , expr : ArrayOrNames ) -> int :
433
+ return id (expr )
434
+
435
+ def post_visit (self , expr : Any ) -> None :
436
+ self .counts [type (expr )] += 1
437
+
438
+
439
+ def get_num_node_types (outputs : Union [Array , DictOfNamedArrays ]) -> int :
440
+ """Returns the number of nodes of each given type in DAG *outputs*."""
441
+
442
+ from pytato .codegen import normalize_outputs
443
+ outputs = normalize_outputs (outputs )
444
+
445
+ ncm = NodeTypeCountMapper ()
446
+ ncm (outputs )
447
+
448
+ return ncm .counts
449
+
450
+ # }}}
451
+
452
+
416
453
# {{{ CallSiteCountMapper
417
454
418
455
@optimize_mapper (drop_args = True , drop_kwargs = True , inline_get_cache_key = True )
You can’t perform that action at this time.
0 commit comments