@@ -1076,8 +1076,15 @@ def visit_For(self, node: ast.For) -> None:
10761076 self .generic_visit (node )
10771077
10781078
1079- def _count_device_loads (device_ir : DeviceIR ) -> int :
1080- """Count the number of load operations in all device code for eviction policy tuning."""
1079+ def _count_device_loads_and_stores (device_ir : DeviceIR ) -> tuple [int , int , int ]:
1080+ """Count the number of load and store operations in device code for autotuning.
1081+
1082+ Returns:
1083+ tuple[int, int, int]: (total_load_count, loads_without_eviction_policy, store_count)
1084+ - total_load_count: all loads (for indexing tunable)
1085+ - loads_without_eviction_policy: loads that need eviction policy tuning
1086+ - store_count: all stores (for indexing tunable)
1087+ """
10811088 from ..language import memory_ops
10821089
10831090 # Build set of rolled graph IDs to exclude (these are duplicates)
@@ -1087,31 +1094,47 @@ def _count_device_loads(device_ir: DeviceIR) -> int:
10871094 if info .new_graph_id is not None
10881095 }
10891096
1090- load_count = 0
1097+ total_load_count = 0
1098+ loads_without_eviction_policy = 0
1099+ store_count = 0
1100+
10911101 # Walk all graphs except rolled duplicates
10921102 for graph_info in device_ir .graphs :
10931103 if graph_info .graph_id in rolled_graph_ids :
10941104 continue
10951105
10961106 for node in graph_info .graph .nodes :
1097- # Check if this is a load operation
1098- if node .op == "call_function" and node .target is memory_ops .load :
1099- # Only count loads without explicit eviction policy
1100- # (user can still specify eviction_policy to override tuning)
1101- # Check kwargs first, then check if 4th arg (eviction_policy) is None
1102- eviction_policy_arg = node .kwargs .get ("eviction_policy" )
1103- if eviction_policy_arg is None :
1104- # Check if eviction_policy was passed as positional arg (index 3)
1105- if len (node .args ) >= 4 :
1106- eviction_policy_arg = node .args [3 ]
1107+ if node .op == "call_function" :
1108+ # Check if this is a load operation
1109+ if node .target is memory_ops .load :
1110+ total_load_count += 1
1111+ # Check if this load needs eviction policy tuning
1112+ # (user can still specify eviction_policy to override tuning)
1113+ eviction_policy_arg = node .kwargs .get ("eviction_policy" )
11071114 if eviction_policy_arg is None :
1108- load_count += 1
1109- return load_count
1110-
1111-
1112- def _register_load_tunables (load_count : int ) -> None :
1113- """Register list-based tunables (indexing, eviction policies) for all device loads."""
1114- if load_count == 0 :
1115+ # Check if eviction_policy was passed as positional arg (index 3)
1116+ if len (node .args ) >= 4 :
1117+ eviction_policy_arg = node .args [3 ]
1118+ if eviction_policy_arg is None :
1119+ loads_without_eviction_policy += 1
1120+ # Check if this is a store operation
1121+ elif node .target is memory_ops .store :
1122+ store_count += 1
1123+
1124+ return total_load_count , loads_without_eviction_policy , store_count
1125+
1126+
1127+ def _register_load_store_tunables (
1128+ total_load_count : int , loads_without_eviction_policy : int , store_count : int
1129+ ) -> None :
1130+ """Register list-based tunables: indexing for all device loads and stores, eviction policies only for loads without an explicit policy.
1131+
1132+ Args:
1133+ total_load_count: Total number of loads (for indexing tunable)
1134+ loads_without_eviction_policy: Number of loads that need eviction policy tuning
1135+ store_count: Total number of stores (for indexing tunable)
1136+ """
1137+ if total_load_count == 0 and store_count == 0 :
11151138 return
11161139
11171140 from ..autotuner .config_fragment import EnumFragment
@@ -1120,13 +1143,21 @@ def _register_load_tunables(load_count: int) -> None:
11201143 from ..autotuner .config_spec import ConfigSpec
11211144
11221145 env = CompileEnvironment .current ()
1123- env .config_spec .load_eviction_policies = ListOf (
1124- EnumFragment (choices = VALID_EVICTION_POLICIES ), length = load_count
1125- )
1126- env .config_spec .indexing = ListOf (
1127- EnumFragment (choices = ConfigSpec ._valid_indexing_types ()), length = load_count
1128- )
1129- env .device_load_count = load_count
1146+
1147+ # Register eviction policies only for loads without explicit eviction_policy
1148+ if loads_without_eviction_policy > 0 :
1149+ env .config_spec .load_eviction_policies = ListOf (
1150+ EnumFragment (choices = VALID_EVICTION_POLICIES ),
1151+ length = loads_without_eviction_policy ,
1152+ )
1153+ env .device_load_count = loads_without_eviction_policy
1154+
1155+ # Indexing applies to ALL loads and stores
1156+ total_count = total_load_count + store_count
1157+ if total_count > 0 :
1158+ env .config_spec .indexing = ListOf (
1159+ EnumFragment (choices = ConfigSpec ._valid_indexing_types ()), length = total_count
1160+ )
11301161
11311162
11321163def lower_to_device_ir (func : HostFunction ) -> DeviceIR :
@@ -1151,9 +1182,13 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
11511182 # xyz not supported with shared program IDs, but persistent kernels are allowed
11521183 CompileEnvironment .current ().config_spec .disallow_pid_type ("xyz" )
11531184
1154- # Count all device loads and register tunables
1155- load_count = _count_device_loads (device_ir )
1156- _register_load_tunables (load_count )
1185+ # Count all device loads and stores and register tunables
1186+ total_load_count , loads_without_eviction_policy , store_count = (
1187+ _count_device_loads_and_stores (device_ir )
1188+ )
1189+ _register_load_store_tunables (
1190+ total_load_count , loads_without_eviction_policy , store_count
1191+ )
11571192
11581193 return device_ir
11591194
0 commit comments