Skip to content

Commit 2627fad

Browse files
committed
[Turbine] Add single-writer write analysis and guarded store codegen
- Implement `analyze_write` to determine single-writer guarantees (proven unique, owner predicate, or needs guard). - Add guarded store emission supporting an owner predicate and an atomic first-writer guard. - Add unit tests for the analysis and for guarded-store dispatch. - Expose the `enable_single_writer_guards` option in the codegen base options. Signed-off-by: Miguel Cárdenas <miguelecsx@gmail.com>
1 parent 137eba8 commit 2627fad

6 files changed

Lines changed: 357 additions & 8 deletions

File tree

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
17
NDEBUG = False
28

9+
from dataclasses import dataclass
10+
311

412
# Root of the codegen exception hierarchy; ValidationError derives from it.
class CodegenError(Exception): ...
513

614

715
# CodegenError subclass; presumably raised when validation of codegen input
# fails — confirm against the raise sites.
class ValidationError(CodegenError): ...
16+
17+
18+
@dataclass
class CodegenOptions:
    """Tunable switches consulted while generating kernel code."""

    # When False, guarded-store emission is bypassed and the store callback
    # is invoked directly, whatever the write analysis concluded.
    enable_single_writer_guards: bool = True
    # Diagnostic verbosity knob. NOTE(review): no visible code reads this
    # field yet; presumably 0 means "quiet" — confirm before relying on it.
    guard_diagnostic_level: int = 0
23+
24+
25+
options = CodegenOptions()

iree/turbine/kernel/compiler/vector_codegen.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,13 @@
4545
CodegenError,
4646
NDEBUG,
4747
ValidationError,
48+
options,
4849
)
4950

51+
from .write_analysis import analyze_write, AnalysisResult
52+
from .write_codegen import emit_guarded_store
53+
54+
5055
from .ir import (
5156
AffineMap,
5257
Attribute,
@@ -416,14 +421,22 @@ def _(emitter: ThreadEmitter, node: fx.Node):
416421
insert_rank = 1
417422

418423
permutation_map = AffineMap.get_identity(dest_rank)
419-
vector_d.transfer_write(
420-
None,
421-
insert_vector,
422-
kb_dest,
423-
start_indices,
424-
AffineMapAttr.get(permutation_map),
425-
in_bounds=[True for _ in range(insert_rank)],
426-
)
424+
425+
# Analyze write for single-writer property (identity mapping = unique)
426+
analysis = analyze_write(None, ref_shape)
427+
428+
def emit_store():
429+
vector_d.transfer_write(
430+
None,
431+
insert_vector,
432+
kb_dest,
433+
start_indices,
434+
AffineMapAttr.get(permutation_map),
435+
in_bounds=[True for _ in range(insert_rank)],
436+
)
437+
438+
emit_guarded_store(emitter, analysis, emit_store)
439+
427440

428441

429442
###############################################################################
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
"""Analysis for single-writer memory operations."""
8+
9+
from __future__ import annotations
10+
11+
from dataclasses import dataclass
12+
from enum import Enum, auto
13+
from typing import Optional, TYPE_CHECKING
14+
import sympy
15+
16+
if TYPE_CHECKING:
17+
from ..lang.tkw_types import IndexMapping
18+
from .._support.indexing import IndexExpr
19+
20+
21+
class AnalysisOutcome(Enum):
    """Verdict of the single-writer analysis for one write operation."""

    # Values are spelled out explicitly; they are the same 1, 2, 3 that
    # auto() would assign in this declaration order.

    # The store can be emitted without any guard.
    PROVEN_UNIQUE = 1
    # The store is wrapped in a grid-axis equality check.
    OWNER_PREDICATE = 2
    # No structure was recognized; an atomic first-writer guard is required.
    NEEDS_GUARD = 3
26+
27+
28+
@dataclass(frozen=True, slots=True)
29+
class OwnerPredicate:
30+
"""Structured owner predicate: grid[axis] == value."""
31+
axis: int
32+
value: int = 0
33+
34+
35+
@dataclass(slots=True)
36+
class AnalysisResult:
37+
"""Analysis result with outcome and optional owner predicate."""
38+
outcome: AnalysisOutcome
39+
predicate: Optional[OwnerPredicate] = None
40+
41+
@classmethod
42+
def unique(cls) -> AnalysisResult:
43+
return cls(AnalysisOutcome.PROVEN_UNIQUE)
44+
45+
@classmethod
46+
def owner(cls, axis: int, value: int = 0) -> AnalysisResult:
47+
return cls(AnalysisOutcome.OWNER_PREDICATE, OwnerPredicate(axis, value))
48+
49+
@classmethod
50+
def guard(cls) -> AnalysisResult:
51+
return cls(AnalysisOutcome.NEEDS_GUARD)
52+
53+
54+
def _is_constant(expr) -> bool:
    """Tell whether ``expr`` is a plain integer or a numeric expression."""
    # Python ints and sympy integers are constants outright.
    if isinstance(expr, (int, sympy.Integer)):
        return True
    # Otherwise defer to the object's own ``is_number`` attribute, treating
    # objects without one as non-constant.
    return getattr(expr, "is_number", False)
57+
58+
59+
def analyze_write(
    mapping: Optional[IndexMapping],
    ref_shape: tuple[IndexExpr, ...],
    has_identity: bool = False,
) -> AnalysisResult:
    """Classify a write by the single-writer guarantee it offers.

    Args:
        mapping: Index mapping applied to the write, or ``None`` when the
            write has no mapping.
        ref_shape: Shape of the destination. Accepted for interface
            purposes; the analysis does not currently consult it.
        has_identity: When true, the mapping is not inspected and the write
            is reported as unique.

    Returns:
        ``AnalysisResult.unique()`` for a missing or identity mapping,
        ``AnalysisResult.owner(...)`` when an owner predicate can be
        extracted, and ``AnalysisResult.guard()`` otherwise.
    """
    # Cheap checks first; ``mapping.is_identity()`` is only reached when a
    # mapping exists and the caller did not already vouch for identity.
    trivially_unique = has_identity or mapping is None
    if trivially_unique or mapping.is_identity():
        return AnalysisResult.unique()

    # Non-identity mapping: look for a recognizable owner pattern.
    owner = _extract_owner_predicate(mapping)
    if not owner:
        return AnalysisResult.guard()
    return AnalysisResult.owner(owner.axis, owner.value)
76+
77+
78+
def _extract_owner_predicate(mapping: IndexMapping) -> Optional[OwnerPredicate]:
    """Try to derive an owner predicate from a non-identity mapping.

    Args:
        mapping: Mapping whose ``output_mapping`` values are inspected.

    Returns:
        ``OwnerPredicate(axis=0, value=0)`` when every output expression is
        a constant, otherwise ``None``. ``analyze_write`` turns ``None``
        into a NEEDS_GUARD result.
    """
    # Pattern: all output dimensions are constants (broadcast/reduction).
    # Note that all() is vacuously true for an empty output_mapping, so an
    # empty mapping also yields the owner predicate.
    # NOTE(review): only grid axis 0 is constrained here — confirm that
    # threads differing on other grid axes cannot reach the same store.
    if all(_is_constant(expr) for expr in mapping.output_mapping.values()):
        return OwnerPredicate(axis=0, value=0)

    # Every other shape of mapping, including many-to-one floor-division
    # mappings, has no recognized owner pattern. An explicit scan for
    # sympy.floor used to sit here, but it returned None on both its hit and
    # its miss path, so it was redundant and has been removed.
    return None
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
"""Code generation for guarded memory writes."""
8+
9+
from __future__ import annotations
10+
11+
from typing import Callable, TYPE_CHECKING
12+
13+
from .ir import (
14+
IndexType,
15+
IntegerType,
16+
MemRefType,
17+
arith_d,
18+
memref_d,
19+
scf_d,
20+
InsertionPoint,
21+
Attribute,
22+
)
23+
from .write_analysis import AnalysisResult, AnalysisOutcome, OwnerPredicate
24+
from .base import options
25+
26+
if TYPE_CHECKING:
27+
from .vector_codegen import ThreadEmitter
28+
29+
30+
def emit_guarded_store(
    emitter: ThreadEmitter,
    analysis: AnalysisResult,
    store_fn: Callable[[], None],
) -> None:
    """Emit a store, wrapped in whatever guard the analysis calls for.

    Args:
        emitter: Emitter forwarded to the guard helpers; not used when the
            store is emitted directly.
        analysis: Result of the write analysis for this store.
        store_fn: Zero-argument callback that emits the actual store.
    """
    # Guards globally disabled: emit the store as-is, ignoring the analysis.
    if not options.enable_single_writer_guards:
        store_fn()
        return

    verdict = analysis.outcome
    if verdict == AnalysisOutcome.PROVEN_UNIQUE:
        # Nothing to protect against.
        store_fn()
    elif verdict == AnalysisOutcome.OWNER_PREDICATE:
        _emit_owner_guard(emitter, analysis.predicate, store_fn)
    elif verdict == AnalysisOutcome.NEEDS_GUARD:
        _emit_atomic_guard(emitter, store_fn)
    # Any other outcome value emits nothing.
47+
48+
49+
def _emit_owner_guard(
    emitter: ThreadEmitter,
    predicate: OwnerPredicate,
    store_fn: Callable[[], None],
) -> None:
    """Wrap the store in an ``scf.if`` that tests the owner predicate.

    The condition compares the value looked up for grid axis
    ``predicate.axis`` against the index constant ``predicate.value``;
    ``store_fn`` is invoked inside the ``then`` block.
    """
    grid_index = emitter.lookup_grid_axis_value(predicate.axis).ir_value
    expected = arith_d.constant(IndexType.get(), predicate.value)
    is_owner = arith_d.cmpi(arith_d.CmpIPredicate.eq, grid_index, expected)

    # No results and no else branch: the guard only gates a side effect.
    guarded = scf_d.IfOp(is_owner, results_=[])
    with InsertionPoint(guarded.then_block):
        store_fn()
        scf_d.yield_([])
63+
64+
65+
def _emit_atomic_guard(
    emitter: ThreadEmitter,
    store_fn: Callable[[], None],
) -> None:
    """Emit an atomic test-and-set guard so that one writer performs the store.

    An ``atomic_rmw addi`` of 1 is issued on a one-element i32 flag; the
    store callback is emitted inside an ``scf.if`` that fires only when the
    value returned by the atomic is 0.

    NOTE(review): this docstring previously said the flag "must be
    pre-allocated in workgroup memory and initialized before kernel", but
    the function allocates the flag itself with ``memref_d.alloca`` and
    never initializes it. Confirm that the alloca yields storage shared by
    all competing writers and that it starts at 0; otherwise the
    first-writer property does not hold.

    Args:
        emitter: Currently unused by this function.
        store_fn: Zero-argument callback that emits the actual store.
    """
    i32 = IntegerType.get_signless(32)
    idx_ty = IndexType.get()

    # One-element i32 flag in memory space "3" (intended as workgroup-local
    # memory — confirm for the target). Nothing below writes an initial
    # value to it.
    flag_type = MemRefType.get([1], i32, memory_space=Attribute.parse("3"))
    flag = memref_d.alloca(flag_type, [], [])

    # Index 0 addresses the flag; the i32 constants are the increment and
    # the "nobody has written yet" sentinel.
    c0 = arith_d.constant(idx_ty, 0)
    c1_i32 = arith_d.constant(i32, 1)
    c0_i32 = arith_d.constant(i32, 0)

    # Atomic add of 1; the result is compared against 0 to pick the first
    # arrival.
    old_val = memref_d.atomic_rmw(arith_d.AtomicRMWKind.addi, c1_i32, flag, [c0])
    is_first = arith_d.cmpi(arith_d.CmpIPredicate.eq, old_val, c0_i32)

    # Only the thread that observed 0 emits the store.
    if_op = scf_d.IfOp(is_first, results_=[])
    with InsertionPoint(if_op.then_block):
        store_fn()
        scf_d.yield_([])
93+
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import logging
8+
import unittest
9+
import sympy
10+
11+
from iree.turbine.kernel.lang import sym
12+
from iree.turbine.kernel.lang.tkw_types import IndexMapping
13+
from iree.turbine.kernel.compiler.write_analysis import (
14+
analyze_write,
15+
AnalysisOutcome,
16+
AnalysisResult,
17+
OwnerPredicate,
18+
)
19+
20+
M = sym.M
21+
N = sym.N
22+
K = sym.K
23+
24+
25+
class WriteAnalysisTest(unittest.TestCase):
    """Checks the outcome analyze_write picks for representative mappings."""

    def testNoneMappingIsUnique(self):
        """A missing mapping is classified as PROVEN_UNIQUE."""
        outcome = analyze_write(None, (M, N)).outcome
        self.assertEqual(outcome, AnalysisOutcome.PROVEN_UNIQUE)

    def testIdentityMappingIsUnique(self):
        """An identity IndexMapping is classified as PROVEN_UNIQUE."""
        row = IndexMapping.iterator(0)
        col = IndexMapping.iterator(1)
        identity = IndexMapping(2, {M: row, N: col}, {M: row, N: col})
        outcome = analyze_write(identity, (M, N)).outcome
        self.assertEqual(outcome, AnalysisOutcome.PROVEN_UNIQUE)

    def testHasIdentityFlagIsUnique(self):
        """has_identity=True short-circuits to PROVEN_UNIQUE."""
        it = IndexMapping.iterator(0)
        # Deliberately non-identity: the output is pinned to a constant.
        broadcast = IndexMapping(1, {M: it}, {N: 0})
        outcome = analyze_write(broadcast, (N,), has_identity=True).outcome
        self.assertEqual(outcome, AnalysisOutcome.PROVEN_UNIQUE)

    def testBroadcastIsOwnerPredicate(self):
        """All-constant outputs are classified as OWNER_PREDICATE on axis 0."""
        it = IndexMapping.iterator(0)
        broadcast = IndexMapping(1, {M: it}, {N: 0})
        analysis = analyze_write(broadcast, (N,))
        self.assertEqual(analysis.outcome, AnalysisOutcome.OWNER_PREDICATE)
        self.assertEqual(analysis.predicate.axis, 0)

    def testFloorDivNeedsGuard(self):
        """A floor-division output is classified as NEEDS_GUARD."""
        it = IndexMapping.iterator(0)
        # Two consecutive iterations land on the same output index.
        halved = IndexMapping(1, {M: it}, {N: sympy.floor(it / 2)})
        outcome = analyze_write(halved, (N,)).outcome
        self.assertEqual(outcome, AnalysisOutcome.NEEDS_GUARD)
63+
64+
65+
class OwnerPredicateTest(unittest.TestCase):
    """Checks the immutability contract of OwnerPredicate."""

    def testFrozen(self):
        """Assigning to a field of an OwnerPredicate raises AttributeError."""
        frozen = OwnerPredicate(axis=0, value=0)
        with self.assertRaises(AttributeError):
            setattr(frozen, "axis", 1)
71+
72+
73+
class AnalysisResultTest(unittest.TestCase):
    """Checks that each AnalysisResult factory sets the expected outcome."""

    def testUniqueFactory(self):
        """unique() yields PROVEN_UNIQUE."""
        self.assertEqual(
            AnalysisResult.unique().outcome, AnalysisOutcome.PROVEN_UNIQUE
        )

    def testOwnerFactory(self):
        """owner() yields OWNER_PREDICATE and records the requested axis."""
        built = AnalysisResult.owner(axis=1, value=0)
        self.assertEqual(built.outcome, AnalysisOutcome.OWNER_PREDICATE)
        self.assertEqual(built.predicate.axis, 1)

    def testGuardFactory(self):
        """guard() yields NEEDS_GUARD."""
        self.assertEqual(
            AnalysisResult.guard().outcome, AnalysisOutcome.NEEDS_GUARD
        )
86+
87+
88+
if __name__ == "__main__":
89+
logging.basicConfig(level=logging.DEBUG)
90+
unittest.main()
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
"""Tests for guarded store dispatch logic.
8+
9+
Note: Full IR emission tests require ThreadEmitter context and are covered
10+
by vector_codegen integration tests. These tests verify dispatch behavior.
11+
"""
12+
13+
import logging
14+
import unittest
15+
16+
from iree.turbine.kernel.compiler.write_analysis import AnalysisResult
17+
from iree.turbine.kernel.compiler.write_codegen import emit_guarded_store
18+
from iree.turbine.kernel.compiler import base
19+
20+
21+
class EmitGuardedStoreTest(unittest.TestCase):
    """Dispatch tests for emit_guarded_store that need no IR context.

    Only paths that call the store callback directly are exercised, so a
    ``None`` emitter is sufficient.
    """

    def setUp(self):
        # The option lives on a module-level object; remember it so each
        # test can flip it freely.
        self._original = base.options.enable_single_writer_guards

    def tearDown(self):
        base.options.enable_single_writer_guards = self._original

    def testProvenUniqueCallsStoreDirectly(self):
        """PROVEN_UNIQUE should call store function without guards."""
        base.options.enable_single_writer_guards = True
        called = []
        emit_guarded_store(None, AnalysisResult.unique(), lambda: called.append(1))
        self.assertEqual(len(called), 1)

    def testGuardsDisabledAlwaysCallsStore(self):
        """When guards disabled, all outcomes call store directly."""
        base.options.enable_single_writer_guards = False
        called = []
        # Cover every outcome, including OWNER_PREDICATE, which this test
        # previously left out despite its "all outcomes" claim.
        analyses = (
            AnalysisResult.unique(),
            AnalysisResult.owner(axis=0),
            AnalysisResult.guard(),
        )
        for analysis in analyses:
            emit_guarded_store(None, analysis, lambda: called.append(1))
        self.assertEqual(len(called), len(analyses))
42+
43+
44+
if __name__ == "__main__":
45+
logging.basicConfig(level=logging.DEBUG)
46+
unittest.main()

0 commit comments

Comments
 (0)