From 2627fadb471e7c9f2d8d4f61704c083c389d5665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20C=C3=A1rdenas?= Date: Thu, 29 Jan 2026 01:05:25 -0500 Subject: [PATCH] [Turbine] Add single-writer write analysis and guarded store codegen MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Implement analyze_write to determine single-writer guarantees (unique, owner predicate, or needs guard). - Add guarded store emission supporting owner predicate and atomic first-writer guard. - Add unit tests for analysis and guarded-store dispatch. - Expose the enable_single_writer_guards option in codegen base options. Signed-off-by: Miguel Cárdenas --- iree/turbine/kernel/compiler/base.py | 18 ++++ .../turbine/kernel/compiler/vector_codegen.py | 29 ++++-- .../turbine/kernel/compiler/write_analysis.py | 89 ++++++++++++++++++ iree/turbine/kernel/compiler/write_codegen.py | 93 +++++++++++++++++++ tests/kernel/compiler/write_analysis_test.py | 90 ++++++++++++++++++ tests/kernel/compiler/write_codegen_test.py | 46 +++++++++ 6 files changed, 357 insertions(+), 8 deletions(-) create mode 100644 iree/turbine/kernel/compiler/write_analysis.py create mode 100644 iree/turbine/kernel/compiler/write_codegen.py create mode 100644 tests/kernel/compiler/write_analysis_test.py create mode 100644 tests/kernel/compiler/write_codegen_test.py diff --git a/iree/turbine/kernel/compiler/base.py b/iree/turbine/kernel/compiler/base.py index 197c842ba..27c228909 100644 --- a/iree/turbine/kernel/compiler/base.py +++ b/iree/turbine/kernel/compiler/base.py @@ -1,7 +1,25 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + NDEBUG = False +from dataclasses import dataclass + class CodegenError(Exception): ... class ValidationError(CodegenError): ... 
+ + +@dataclass +class CodegenOptions: + """Configuration options for kernel code generation.""" + enable_single_writer_guards: bool = True + guard_diagnostic_level: int = 0 + + +options = CodegenOptions() diff --git a/iree/turbine/kernel/compiler/vector_codegen.py b/iree/turbine/kernel/compiler/vector_codegen.py index 50a8dfc90..b35e15b3f 100644 --- a/iree/turbine/kernel/compiler/vector_codegen.py +++ b/iree/turbine/kernel/compiler/vector_codegen.py @@ -45,8 +45,13 @@ CodegenError, NDEBUG, ValidationError, + options, ) +from .write_analysis import analyze_write, AnalysisResult +from .write_codegen import emit_guarded_store + + from .ir import ( AffineMap, Attribute, @@ -416,14 +421,22 @@ def _(emitter: ThreadEmitter, node: fx.Node): insert_rank = 1 permutation_map = AffineMap.get_identity(dest_rank) - vector_d.transfer_write( - None, - insert_vector, - kb_dest, - start_indices, - AffineMapAttr.get(permutation_map), - in_bounds=[True for _ in range(insert_rank)], - ) + + # Analyze write for single-writer property (identity mapping = unique) + analysis = analyze_write(None, ref_shape) + + def emit_store(): + vector_d.transfer_write( + None, + insert_vector, + kb_dest, + start_indices, + AffineMapAttr.get(permutation_map), + in_bounds=[True for _ in range(insert_rank)], + ) + + emit_guarded_store(emitter, analysis, emit_store) + ############################################################################### diff --git a/iree/turbine/kernel/compiler/write_analysis.py b/iree/turbine/kernel/compiler/write_analysis.py new file mode 100644 index 000000000..f18c51fbe --- /dev/null +++ b/iree/turbine/kernel/compiler/write_analysis.py @@ -0,0 +1,89 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Analysis for single-writer memory operations.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum, auto +from typing import Optional, TYPE_CHECKING +import sympy + +if TYPE_CHECKING: + from ..lang.tkw_types import IndexMapping + from .._support.indexing import IndexExpr + + +class AnalysisOutcome(Enum): + """Result of analyzing a write operation for single-writer property.""" + PROVEN_UNIQUE = auto() + OWNER_PREDICATE = auto() + NEEDS_GUARD = auto() + + +@dataclass(frozen=True, slots=True) +class OwnerPredicate: + """Structured owner predicate: grid[axis] == value.""" + axis: int + value: int = 0 + + +@dataclass(slots=True) +class AnalysisResult: + """Analysis result with outcome and optional owner predicate.""" + outcome: AnalysisOutcome + predicate: Optional[OwnerPredicate] = None + + @classmethod + def unique(cls) -> AnalysisResult: + return cls(AnalysisOutcome.PROVEN_UNIQUE) + + @classmethod + def owner(cls, axis: int, value: int = 0) -> AnalysisResult: + return cls(AnalysisOutcome.OWNER_PREDICATE, OwnerPredicate(axis, value)) + + @classmethod + def guard(cls) -> AnalysisResult: + return cls(AnalysisOutcome.NEEDS_GUARD) + + +def _is_constant(expr) -> bool: + """Check if expression is a constant value.""" + return isinstance(expr, (int, sympy.Integer)) or getattr(expr, 'is_number', False) + + +def analyze_write( + mapping: Optional[IndexMapping], + ref_shape: tuple[IndexExpr, ...], + has_identity: bool = False, +) -> AnalysisResult: + """Analyze write for single-writer property.""" + # Fast path: identity mapping = each thread writes unique location + if has_identity or mapping is None: + return AnalysisResult.unique() + + # Check identity via mapping API (reuses existing infrastructure) + if mapping.is_identity(): + return AnalysisResult.unique() + + # Attempt owner predicate extraction for reduction patterns + pred = 
_extract_owner_predicate(mapping) + return AnalysisResult.owner(pred.axis, pred.value) if pred else AnalysisResult.guard() + + +def _extract_owner_predicate(mapping: IndexMapping) -> Optional[OwnerPredicate]: + """Extract canonical owner predicate from non-identity mapping.""" + # Pattern: All output dimensions are constants (broadcast/reduction) + if all(_is_constant(expr) for expr in mapping.output_mapping.values()): + return OwnerPredicate(axis=0, value=0) + + # Pattern: Floor division creates many-to-one mapping - needs full guard + for expr in mapping.output_mapping.values(): + if isinstance(expr, sympy.Expr) and expr.has(sympy.floor): + return None + + return None diff --git a/iree/turbine/kernel/compiler/write_codegen.py b/iree/turbine/kernel/compiler/write_codegen.py new file mode 100644 index 000000000..c751374c4 --- /dev/null +++ b/iree/turbine/kernel/compiler/write_codegen.py @@ -0,0 +1,93 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Code generation for guarded memory writes.""" + +from __future__ import annotations + +from typing import Callable, TYPE_CHECKING + +from .ir import ( + IndexType, + IntegerType, + MemRefType, + arith_d, + memref_d, + scf_d, + InsertionPoint, + Attribute, +) +from .write_analysis import AnalysisResult, AnalysisOutcome, OwnerPredicate +from .base import options + +if TYPE_CHECKING: + from .vector_codegen import ThreadEmitter + + +def emit_guarded_store( + emitter: ThreadEmitter, + analysis: AnalysisResult, + store_fn: Callable[[], None], +) -> None: + """Emit store with appropriate guard based on analysis result.""" + if not options.enable_single_writer_guards: + store_fn() + return + + match analysis.outcome: + case AnalysisOutcome.PROVEN_UNIQUE: + store_fn() + case AnalysisOutcome.OWNER_PREDICATE: + _emit_owner_guard(emitter, analysis.predicate, store_fn) + case AnalysisOutcome.NEEDS_GUARD: + _emit_atomic_guard(emitter, store_fn) + + +def _emit_owner_guard( + emitter: ThreadEmitter, + predicate: OwnerPredicate, + store_fn: Callable[[], None], +) -> None: + """Emit scf.if guard: execute store only if predicate holds.""" + axis_val = emitter.lookup_grid_axis_value(predicate.axis).ir_value + const_val = arith_d.constant(IndexType.get(), predicate.value) + cond = arith_d.cmpi(arith_d.CmpIPredicate.eq, axis_val, const_val) + + if_op = scf_d.IfOp(cond, results_=[]) + with InsertionPoint(if_op.then_block): + store_fn() + scf_d.yield_([]) + + +def _emit_atomic_guard( + emitter: ThreadEmitter, + store_fn: Callable[[], None], +) -> None: + """Emit atomic test-and-set guard for first-writer-wins semantics. + + Uses atomic_rmw to ensure only one thread (the first to arrive) executes store. + The flag must be pre-allocated in workgroup memory and initialized before kernel. 
+ """ + i32 = IntegerType.get_signless(32) + idx_ty = IndexType.get() + + # Workgroup-local flag (address space 3) - must be pre-initialized + flag_type = MemRefType.get([1], i32, memory_space=Attribute.parse("3")) + flag = memref_d.alloca(flag_type, [], []) + + c0 = arith_d.constant(idx_ty, 0) + c1_i32 = arith_d.constant(i32, 1) + c0_i32 = arith_d.constant(i32, 0) + + # Atomic add 1, returns old value - first thread gets 0 + old_val = memref_d.atomic_rmw(arith_d.AtomicRMWKind.addi, c1_i32, flag, [c0]) + is_first = arith_d.cmpi(arith_d.CmpIPredicate.eq, old_val, c0_i32) + + if_op = scf_d.IfOp(is_first, results_=[]) + with InsertionPoint(if_op.then_block): + store_fn() + scf_d.yield_([]) + diff --git a/tests/kernel/compiler/write_analysis_test.py b/tests/kernel/compiler/write_analysis_test.py new file mode 100644 index 000000000..424bec1e2 --- /dev/null +++ b/tests/kernel/compiler/write_analysis_test.py @@ -0,0 +1,90 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest +import sympy + +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel.lang.tkw_types import IndexMapping +from iree.turbine.kernel.compiler.write_analysis import ( + analyze_write, + AnalysisOutcome, + AnalysisResult, + OwnerPredicate, +) + +M = sym.M +N = sym.N +K = sym.K + + +class WriteAnalysisTest(unittest.TestCase): + def testNoneMappingIsUnique(self): + """None mapping (identity) should be PROVEN_UNIQUE.""" + result = analyze_write(None, (M, N)) + self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE) + + def testIdentityMappingIsUnique(self): + """Identity IndexMapping should be PROVEN_UNIQUE.""" + i0 = IndexMapping.iterator(0) + i1 = IndexMapping.iterator(1) + mapping = IndexMapping(2, {M: i0, N: i1}, {M: i0, N: i1}) + result = analyze_write(mapping, (M, N)) + self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE) + + def testHasIdentityFlagIsUnique(self): + """has_identity=True should bypass analysis and return PROVEN_UNIQUE.""" + i0 = IndexMapping.iterator(0) + # Non-identity mapping (broadcast to constant) + mapping = IndexMapping(1, {M: i0}, {N: 0}) + result = analyze_write(mapping, (N,), has_identity=True) + self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE) + + def testBroadcastIsOwnerPredicate(self): + """Broadcast (constant output) should be OWNER_PREDICATE.""" + i0 = IndexMapping.iterator(0) + # All outputs are constants - broadcast pattern + mapping = IndexMapping(1, {M: i0}, {N: 0}) + result = analyze_write(mapping, (N,)) + self.assertEqual(result.outcome, AnalysisOutcome.OWNER_PREDICATE) + self.assertEqual(result.predicate.axis, 0) + + def testFloorDivNeedsGuard(self): + """Floor division pattern should be NEEDS_GUARD.""" + i0 = IndexMapping.iterator(0) + # Floor division creates many-to-one mapping + mapping = IndexMapping(1, {M: i0}, {N: sympy.floor(i0 / 2)}) + result = analyze_write(mapping, (N,)) + 
self.assertEqual(result.outcome, AnalysisOutcome.NEEDS_GUARD) + + +class OwnerPredicateTest(unittest.TestCase): + def testFrozen(self): + """OwnerPredicate should be immutable.""" + pred = OwnerPredicate(axis=0, value=0) + with self.assertRaises(AttributeError): + pred.axis = 1 + + +class AnalysisResultTest(unittest.TestCase): + def testUniqueFactory(self): + result = AnalysisResult.unique() + self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE) + + def testOwnerFactory(self): + result = AnalysisResult.owner(axis=1, value=0) + self.assertEqual(result.outcome, AnalysisOutcome.OWNER_PREDICATE) + self.assertEqual(result.predicate.axis, 1) + + def testGuardFactory(self): + result = AnalysisResult.guard() + self.assertEqual(result.outcome, AnalysisOutcome.NEEDS_GUARD) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/kernel/compiler/write_codegen_test.py b/tests/kernel/compiler/write_codegen_test.py new file mode 100644 index 000000000..78a1b7920 --- /dev/null +++ b/tests/kernel/compiler/write_codegen_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Tests for guarded store dispatch logic. + +Note: Full IR emission tests require ThreadEmitter context and are covered +by vector_codegen integration tests. These tests verify dispatch behavior. 
+""" + +import logging +import unittest + +from iree.turbine.kernel.compiler.write_analysis import AnalysisResult +from iree.turbine.kernel.compiler.write_codegen import emit_guarded_store +from iree.turbine.kernel.compiler import base + + +class EmitGuardedStoreTest(unittest.TestCase): + def setUp(self): + self._original = base.options.enable_single_writer_guards + + def tearDown(self): + base.options.enable_single_writer_guards = self._original + + def testProvenUniqueCallsStoreDirectly(self): + """PROVEN_UNIQUE should call store function without guards.""" + base.options.enable_single_writer_guards = True + called = [] + emit_guarded_store(None, AnalysisResult.unique(), lambda: called.append(1)) + self.assertEqual(len(called), 1) + + def testGuardsDisabledAlwaysCallsStore(self): + """When guards disabled, all outcomes call store directly.""" + base.options.enable_single_writer_guards = False + called = [] + emit_guarded_store(None, AnalysisResult.unique(), lambda: called.append(1)) + emit_guarded_store(None, AnalysisResult.guard(), lambda: called.append(1)) + self.assertEqual(len(called), 2) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main()