
Commit c5dbbbe

Make store indexing also individually tunable (#1028)
1 parent 0efcf06 commit c5dbbbe

11 files changed: +393 -112

README.md

Lines changed: 6 additions & 5 deletions
@@ -35,7 +35,7 @@ portable between different hardware. Helion automates and autotunes over:
 
 * Automatically calculates strides and indices.
 * Autotunes choices among various indexing methods (pointers, block pointers, TensorDescriptors).
-* Supports per-load indexing strategies for fine-grained memory access control.
+* Supports per-operation indexing strategies for fine-grained memory access control of loads and stores.
 
 2. **Masking:**
 
@@ -259,10 +259,11 @@ cache behavior. A value of `1` disables this optimization, while higher
 values specify the grouping size.
 
 * **indexing** (`"pointer"`, `"tensor_descriptor"`, `"block_ptr"`, or a list of these):
-  Specifies the memory indexing strategy for load operations. Can be:
-  - A single strategy (applies to all loads): `indexing="block_ptr"`
-  - A list of strategies (one per load operation): `indexing=["pointer", "block_ptr", "tensor_descriptor"]`
-  - Empty/omitted (defaults to `"pointer"` for all loads)
+  Specifies the memory indexing strategy for load and store operations. Can be:
+  - A single strategy (applies to all loads and stores): `indexing="block_ptr"`
+  - A list of strategies (one per load/store in execution order): `indexing=["pointer", "pointer", "block_ptr"]`
+  - Empty/omitted (defaults to `"pointer"` for all operations)
+  - When using a list, provide strategies in order: `[load1, load2, ..., store1, store2, ...]`
 
 The `"tensor_descriptor"` option uses Tensor Memory Accelerators (TMAs) but
 requires a Hopper or newer GPU and the latest development version of Triton.
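To make the ordering rule concrete, here is a minimal sketch (not part of this commit) that mirrors the example added to docs/api/config.md further down; the kernel body and config keys are taken from that example, while the kernel name `add_tiles` and tensor shapes are illustrative assumptions:

```python
import torch
import helion
import helion.language as hl

# Two loads followed by one store, so the indexing list needs three entries,
# ordered [load1, load2, store1] as described above.
@helion.kernel(config={"block_size": 16, "indexing": ["pointer", "pointer", "block_ptr"]})
def add_tiles(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        a = hl.load(x, [tile])  # load 1 -> "pointer"
        b = hl.load(y, [tile])  # load 2 -> "pointer"
        out[tile] = a + b       # store 1 -> "block_ptr"
    return out
```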

docs/api/config.md

Lines changed: 23 additions & 19 deletions
@@ -109,31 +109,37 @@ Configs are typically discovered automatically through autotuning, but can also
 
 .. autoattribute:: Config.indexing
 
-    Memory indexing strategy for load operations. Can be specified as:
+    Memory indexing strategy for load and store operations. Can be specified as:
 
-    **Single strategy (applies to all loads - backward compatible):**
+    **Single strategy (applies to all loads and stores - backward compatible):**
 
     .. code-block:: python
 
-        indexing="block_ptr"  # All loads use block pointers
+        indexing="block_ptr"  # All loads and stores use block pointers
 
-    **Per-load strategies (list, one per load operation):**
+    **Per-operation strategies (list, one per load/store in execution order):**
 
     .. code-block:: python
 
-        indexing=["pointer", "block_ptr", "tensor_descriptor"]
+        # 2 loads + 1 store = 3 indexing strategies
+        indexing=["pointer", "pointer", "block_ptr"]  # loads use pointer, store uses block_ptr
 
-    **Empty/omitted (defaults to** ``"pointer"`` **for all loads):**
+    **Empty/omitted (defaults to** ``"pointer"`` **for all operations):**
 
     .. code-block:: python
 
-        # indexing not specified - all loads use pointer indexing
+        # indexing not specified - all loads and stores use pointer indexing
 
     **Valid strategies:**
 
     - ``"pointer"``: Pointer-based indexing (default)
     - ``"tensor_descriptor"``: Tensor descriptor indexing (requires Hopper+ GPU)
     - ``"block_ptr"``: Block pointer indexing
+
+    .. note::
+        When using a list, provide one strategy for each load and store operation in the order
+        they appear in the kernel. The indexing list is ordered as:
+        ``[load1, load2, ..., loadN, store1, store2, ..., storeM]``
 ```
 
 ### Memory and Caching
@@ -212,32 +218,30 @@ import torch
 import helion
 import helion.language as hl
 
-# Single indexing strategy for all loads (backward compatible)
+# Single indexing strategy for all loads and stores (backward compatible)
 @helion.kernel(config={"indexing": "block_ptr"})
 def kernel_uniform_indexing(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     out = torch.empty_like(x)
     for tile in hl.tile(x.size(0)):
-        a = hl.load(x, [tile])  # Uses block_ptr
-        b = hl.load(y, [tile])  # Uses block_ptr
-        out[tile] = a + b
+        a = hl.load(x, [tile])  # Load: uses block_ptr
+        b = hl.load(y, [tile])  # Load: uses block_ptr
+        out[tile] = a + b  # Store: uses block_ptr
     return out
 
-# Per-load indexing strategies for fine-grained control
+# Per-operation indexing strategies for fine-grained control
+# Indexing list is ordered: [load1, load2, ..., store1, store2, ...]
 @helion.kernel(
     config={
         "block_size": 16,
-        "indexing": ["pointer", "block_ptr", "tensor_descriptor"],
+        "indexing": ["pointer", "pointer", "block_ptr"],  # 2 loads + 1 store
     }
 )
-def kernel_mixed_indexing(
-    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-) -> torch.Tensor:
+def kernel_mixed_indexing(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     out = torch.empty_like(x)
     for tile in hl.tile(x.size(0)):
         a = hl.load(x, [tile])  # First load: pointer indexing
-        b = hl.load(y, [tile])  # Second load: block_ptr indexing
-        c = hl.load(z, [tile])  # Third load: tensor_descriptor indexing
-        out[tile] = a + b + c
+        b = hl.load(y, [tile])  # Second load: pointer indexing
+        out[tile] = a + b  # Store: block_ptr indexing
     return out
 ```
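As a usage note beyond what the diff shows: Helion kernels decorated this way are called like ordinary functions on device tensors. The invocation below is a hypothetical sketch; the tensor shape, device, and correctness check are assumptions, not part of this commit:

```python
# Hypothetical call to kernel_mixed_indexing defined above; the fixed config
# (including the per-operation indexing list) is applied when the kernel runs.
x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = kernel_mixed_indexing(x, y)
torch.testing.assert_close(out, x + y)
```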

docs/index.md

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ portable between different hardware. Helion automates and autotunes over:
 
 * Automatically calculates strides and indices.
 * Autotunes choices among various indexing methods (pointers, block pointers, TensorDescriptors).
-* Supports per-load indexing strategies for fine-grained memory access control.
+* Supports per-operation indexing strategies for fine-grained memory access control of loads and stores.
 
 2. **Masking:**

helion/_compiler/device_function.py

Lines changed: 11 additions & 11 deletions
@@ -247,40 +247,40 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
 
         self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config)
 
-        # Store indexing config to lazily create strategies per load
+        # Store indexing config to lazily create strategies per load/store
         self._indexing_config = config.indexing
         self.indexing_strategies: list[IndexingStrategy] = []
-        self.tensor_to_load_index: dict[
-            int, int
-        ] = {}  # Maps tensor id to its load index
 
         self.rng_seed_count = 0
         self.device_load_index = 0
+        self.device_store_index = 0
+        # Single counter for both loads and stores for indexing assignment
+        self.device_memory_op_index = 0
         self.rng_seed_buffer_param_name = None
 
-    def get_indexing_strategy(self, load_index: int) -> IndexingStrategy:
+    def get_indexing_strategy(self, index: int) -> IndexingStrategy:
         from typing import cast
 
         from .indexing_strategy import IndexingStrategy
         from .indexing_strategy import PointerIndexingStrategy
 
         # Expand strategies list if needed
-        while len(self.indexing_strategies) <= load_index:
+        while len(self.indexing_strategies) <= index:
             idx = len(self.indexing_strategies)
 
             if isinstance(self._indexing_config, str):
-                # Single string: all loads use the same strategy
+                # Single string: all loads/stores use the same strategy
                 if not self.indexing_strategies:
                     strategy = IndexingStrategy.select(
                         cast("IndexingLiteral", self._indexing_config)
                     )
                 else:
                     strategy = self.indexing_strategies[0]
             elif isinstance(self._indexing_config, list) and self._indexing_config:
-                # List: one strategy per load
+                # List: one strategy per load/store
                 assert idx < len(self._indexing_config), (
-                    f"Load operation {idx} exceeds indexing config length "
-                    f"{len(self._indexing_config)}. Please specify indexing for all loads."
+                    f"Load/Store operation {idx} exceeds indexing config length "
+                    f"{len(self._indexing_config)}. Please specify indexing for all loads and stores."
                 )
                 strategy = IndexingStrategy.select(
                     cast("IndexingLiteral", self._indexing_config[idx])
@@ -291,7 +291,7 @@ def get_indexing_strategy(self, load_index: int) -> IndexingStrategy:
 
         self.indexing_strategies.append(strategy)
 
-        return self.indexing_strategies[load_index]
+        return self.indexing_strategies[index]
 
     def has_rng_ops(self) -> bool:
         """Check if this kernel uses any RNG operations."""

helion/_compiler/device_ir.py

Lines changed: 65 additions & 30 deletions
@@ -1076,8 +1076,15 @@ def visit_For(self, node: ast.For) -> None:
         self.generic_visit(node)
 
 
-def _count_device_loads(device_ir: DeviceIR) -> int:
-    """Count the number of load operations in all device code for eviction policy tuning."""
+def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int, int]:
+    """Count the number of load and store operations in device code for autotuning.
+
+    Returns:
+        tuple[int, int, int]: (total_load_count, loads_without_eviction_policy, store_count)
+        - total_load_count: all loads (for indexing tunable)
+        - loads_without_eviction_policy: loads that need eviction policy tuning
+        - store_count: all stores (for indexing tunable)
+    """
     from ..language import memory_ops
 
     # Build set of rolled graph IDs to exclude (these are duplicates)
@@ -1087,31 +1094,47 @@ def _count_device_loads(device_ir: DeviceIR) -> int:
         if info.new_graph_id is not None
     }
 
-    load_count = 0
+    total_load_count = 0
+    loads_without_eviction_policy = 0
+    store_count = 0
+
     # Walk all graphs except rolled duplicates
     for graph_info in device_ir.graphs:
         if graph_info.graph_id in rolled_graph_ids:
             continue
 
         for node in graph_info.graph.nodes:
-            # Check if this is a load operation
-            if node.op == "call_function" and node.target is memory_ops.load:
-                # Only count loads without explicit eviction policy
-                # (user can still specify eviction_policy to override tuning)
-                # Check kwargs first, then check if 4th arg (eviction_policy) is None
-                eviction_policy_arg = node.kwargs.get("eviction_policy")
-                if eviction_policy_arg is None:
-                    # Check if eviction_policy was passed as positional arg (index 3)
-                    if len(node.args) >= 4:
-                        eviction_policy_arg = node.args[3]
+            if node.op == "call_function":
+                # Check if this is a load operation
+                if node.target is memory_ops.load:
+                    total_load_count += 1
+                    # Check if this load needs eviction policy tuning
+                    # (user can still specify eviction_policy to override tuning)
+                    eviction_policy_arg = node.kwargs.get("eviction_policy")
                     if eviction_policy_arg is None:
-                        load_count += 1
-    return load_count
-
-
-def _register_load_tunables(load_count: int) -> None:
-    """Register list-based tunables (indexing, eviction policies) for all device loads."""
-    if load_count == 0:
+                        # Check if eviction_policy was passed as positional arg (index 3)
+                        if len(node.args) >= 4:
+                            eviction_policy_arg = node.args[3]
+                        if eviction_policy_arg is None:
+                            loads_without_eviction_policy += 1
+                # Check if this is a store operation
+                elif node.target is memory_ops.store:
+                    store_count += 1
+
+    return total_load_count, loads_without_eviction_policy, store_count
+
+
+def _register_load_store_tunables(
+    total_load_count: int, loads_without_eviction_policy: int, store_count: int
+) -> None:
+    """Register list-based tunables (indexing, eviction policies) for all device loads and stores.
+
+    Args:
+        total_load_count: Total number of loads (for indexing tunable)
+        loads_without_eviction_policy: Number of loads that need eviction policy tuning
+        store_count: Total number of stores (for indexing tunable)
+    """
+    if total_load_count == 0 and store_count == 0:
         return
 
     from ..autotuner.config_fragment import EnumFragment
@@ -1120,13 +1143,21 @@ def _register_load_tunables(load_count: int) -> None:
     from ..autotuner.config_spec import ConfigSpec
 
     env = CompileEnvironment.current()
-    env.config_spec.load_eviction_policies = ListOf(
-        EnumFragment(choices=VALID_EVICTION_POLICIES), length=load_count
-    )
-    env.config_spec.indexing = ListOf(
-        EnumFragment(choices=ConfigSpec._valid_indexing_types()), length=load_count
-    )
-    env.device_load_count = load_count
+
+    # Register eviction policies only for loads without explicit eviction_policy
+    if loads_without_eviction_policy > 0:
+        env.config_spec.load_eviction_policies = ListOf(
+            EnumFragment(choices=VALID_EVICTION_POLICIES),
+            length=loads_without_eviction_policy,
+        )
+        env.device_load_count = loads_without_eviction_policy
+
+    # Indexing applies to ALL loads and stores
+    total_count = total_load_count + store_count
+    if total_count > 0:
+        env.config_spec.indexing = ListOf(
+            EnumFragment(choices=ConfigSpec._valid_indexing_types()), length=total_count
+        )
 
 
 def lower_to_device_ir(func: HostFunction) -> DeviceIR:
@@ -1151,9 +1182,13 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     # xyz not supported with shared program IDs, but persistent kernels are allowed
     CompileEnvironment.current().config_spec.disallow_pid_type("xyz")
 
-    # Count all device loads and register tunables
-    load_count = _count_device_loads(device_ir)
-    _register_load_tunables(load_count)
+    # Count all device loads and stores and register tunables
+    total_load_count, loads_without_eviction_policy, store_count = (
+        _count_device_loads_and_stores(device_ir)
+    )
+    _register_load_store_tunables(
+        total_load_count, loads_without_eviction_policy, store_count
+    )
 
     return device_ir
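One practical consequence of the split counts above: the indexing tunable gets one entry per load and per store, while the eviction-policy tunable only covers loads that did not pass an explicit eviction_policy. A toy sketch of that rule follows; the op tuples are illustrative stand-ins, not the real FX-graph walk:

```python
# (kind, has_explicit_eviction_policy) stand-ins for FX nodes, in kernel order.
ops = [("load", False), ("load", True), ("store", False)]

total_loads = sum(1 for kind, _ in ops if kind == "load")
loads_needing_eviction_tuning = sum(
    1 for kind, explicit in ops if kind == "load" and not explicit
)
stores = sum(1 for kind, _ in ops if kind == "store")

assert total_loads + stores == 3           # indexing list length: every load and store
assert loads_needing_eviction_tuning == 1  # eviction-policy list length: tunable loads only
```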

helion/language/memory_ops.py

Lines changed: 9 additions & 6 deletions
@@ -97,9 +97,11 @@ def _(state: CodegenState) -> ast.AST:
 
     if isinstance(tensor, torch.Tensor):
         device_fn = state.device_function
-        # Use the same strategy that was used to load this tensor, or default to first strategy
-        load_idx = device_fn.tensor_to_load_index.get(id(tensor), 0)
-        strategy = device_fn.get_indexing_strategy(load_idx)
+        device_fn.device_store_index += 1
+        # Use the shared memory op index for indexing strategy
+        indexing_idx = device_fn.device_memory_op_index
+        device_fn.device_memory_op_index += 1
+        strategy = device_fn.get_indexing_strategy(indexing_idx)
         return strategy.codegen_store(state, tensor, [*subscript], value, extra_mask)
     if isinstance(tensor, tuple):
         from .._compiler.indexing_strategy import StackIndexingStrategy
@@ -268,9 +270,10 @@ def _(state: CodegenState) -> ast.AST:
         eviction_policy = ast.Constant(value=eviction_policy)
 
     if isinstance(tensor, torch.Tensor):
-        strategy = device_fn.get_indexing_strategy(load_idx)
-        # Track which strategy was used for this tensor so stores can use the same one
-        device_fn.tensor_to_load_index[id(tensor)] = load_idx
+        # Use the shared memory op index for indexing strategy
+        indexing_idx = device_fn.device_memory_op_index
+        device_fn.device_memory_op_index += 1
+        strategy = device_fn.get_indexing_strategy(indexing_idx)
         return strategy.codegen_load(
             state, tensor, [*subscript], extra_mask, eviction_policy
         )

helion/runtime/config.py

Lines changed: 4 additions & 4 deletions
@@ -61,12 +61,12 @@ def __init__(
             num_warps: Number of warps per block.
             num_stages: Number of stages for software pipelining.
             pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
-            indexing: Indexing strategy for load operations. Can be:
-                - A single strategy string (all loads use this strategy):
+            indexing: Indexing strategy for load and store operations. Can be:
+                - A single strategy string (all loads/stores use this strategy):
                   indexing="block_ptr"  # backward compatible
-                - A list of strategies (one per load operation, must specify all):
+                - A list of strategies (one per load/store operation, must specify all):
                   indexing=["pointer", "block_ptr", "tensor_descriptor"]
-                - Empty/omitted (all loads default to "pointer")
+                - Empty/omitted (all loads/stores default to "pointer")
             Valid strategies: "pointer", "tensor_descriptor", "block_ptr"
             **kwargs: Additional user-defined configuration parameters.
         """
