Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions iree/turbine/kernel/compiler/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

NDEBUG = False

from dataclasses import dataclass


class CodegenError(Exception): ...


class ValidationError(CodegenError): ...


@dataclass
class CodegenOptions:
"""Configuration options for kernel code generation."""
enable_single_writer_guards: bool = True
guard_diagnostic_level: int = 0


options = CodegenOptions()
29 changes: 21 additions & 8 deletions iree/turbine/kernel/compiler/vector_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@
CodegenError,
NDEBUG,
ValidationError,
options,
)

from .write_analysis import analyze_write, AnalysisResult
from .write_codegen import emit_guarded_store


from .ir import (
AffineMap,
Attribute,
Expand Down Expand Up @@ -416,14 +421,22 @@ def _(emitter: ThreadEmitter, node: fx.Node):
insert_rank = 1

permutation_map = AffineMap.get_identity(dest_rank)
vector_d.transfer_write(
None,
insert_vector,
kb_dest,
start_indices,
AffineMapAttr.get(permutation_map),
in_bounds=[True for _ in range(insert_rank)],
)

# Analyze write for single-writer property (identity mapping = unique)
analysis = analyze_write(None, ref_shape)

def emit_store():
vector_d.transfer_write(
None,
insert_vector,
kb_dest,
start_indices,
AffineMapAttr.get(permutation_map),
in_bounds=[True for _ in range(insert_rank)],
)

emit_guarded_store(emitter, analysis, emit_store)



###############################################################################
Expand Down
89 changes: 89 additions & 0 deletions iree/turbine/kernel/compiler/write_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Analysis for single-writer memory operations."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, TYPE_CHECKING
import sympy

if TYPE_CHECKING:
from ..lang.tkw_types import IndexMapping
from .._support.indexing import IndexExpr


class AnalysisOutcome(Enum):
"""Result of analyzing a write operation for single-writer property."""
PROVEN_UNIQUE = auto()
OWNER_PREDICATE = auto()
NEEDS_GUARD = auto()


@dataclass(frozen=True, slots=True)
class OwnerPredicate:
"""Structured owner predicate: grid[axis] == value."""
axis: int
value: int = 0


@dataclass(slots=True)
class AnalysisResult:
"""Analysis result with outcome and optional owner predicate."""
outcome: AnalysisOutcome
predicate: Optional[OwnerPredicate] = None

@classmethod
def unique(cls) -> AnalysisResult:
return cls(AnalysisOutcome.PROVEN_UNIQUE)

@classmethod
def owner(cls, axis: int, value: int = 0) -> AnalysisResult:
return cls(AnalysisOutcome.OWNER_PREDICATE, OwnerPredicate(axis, value))

@classmethod
def guard(cls) -> AnalysisResult:
return cls(AnalysisOutcome.NEEDS_GUARD)


def _is_constant(expr) -> bool:
"""Check if expression is a constant value."""
return isinstance(expr, (int, sympy.Integer)) or getattr(expr, 'is_number', False)


def analyze_write(
mapping: Optional[IndexMapping],
ref_shape: tuple[IndexExpr, ...],
has_identity: bool = False,
) -> AnalysisResult:
"""Analyze write for single-writer property."""
# Fast path: identity mapping = each thread writes unique location
if has_identity or mapping is None:
return AnalysisResult.unique()

# Check identity via mapping API (reuses existing infrastructure)
if mapping.is_identity():
return AnalysisResult.unique()

# Attempt owner predicate extraction for reduction patterns
pred = _extract_owner_predicate(mapping)
return AnalysisResult.owner(pred.axis, pred.value) if pred else AnalysisResult.guard()


def _extract_owner_predicate(mapping: IndexMapping) -> Optional[OwnerPredicate]:
"""Extract canonical owner predicate from non-identity mapping."""
# Pattern: All output dimensions are constants (broadcast/reduction)
if all(_is_constant(expr) for expr in mapping.output_mapping.values()):
return OwnerPredicate(axis=0, value=0)

# Pattern: Floor division creates many-to-one mapping - needs full guard
for expr in mapping.output_mapping.values():
if isinstance(expr, sympy.Expr) and expr.has(sympy.floor):
return None

return None
93 changes: 93 additions & 0 deletions iree/turbine/kernel/compiler/write_codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Code generation for guarded memory writes."""

from __future__ import annotations

from typing import Callable, TYPE_CHECKING

from .ir import (
IndexType,
IntegerType,
MemRefType,
arith_d,
memref_d,
scf_d,
InsertionPoint,
Attribute,
)
from .write_analysis import AnalysisResult, AnalysisOutcome, OwnerPredicate
from .base import options

if TYPE_CHECKING:
from .vector_codegen import ThreadEmitter


def emit_guarded_store(
emitter: ThreadEmitter,
analysis: AnalysisResult,
store_fn: Callable[[], None],
) -> None:
"""Emit store with appropriate guard based on analysis result."""
if not options.enable_single_writer_guards:
store_fn()
return

match analysis.outcome:
case AnalysisOutcome.PROVEN_UNIQUE:
store_fn()
case AnalysisOutcome.OWNER_PREDICATE:
_emit_owner_guard(emitter, analysis.predicate, store_fn)
case AnalysisOutcome.NEEDS_GUARD:
_emit_atomic_guard(emitter, store_fn)


def _emit_owner_guard(
emitter: ThreadEmitter,
predicate: OwnerPredicate,
store_fn: Callable[[], None],
) -> None:
"""Emit scf.if guard: execute store only if predicate holds."""
axis_val = emitter.lookup_grid_axis_value(predicate.axis).ir_value
const_val = arith_d.constant(IndexType.get(), predicate.value)
cond = arith_d.cmpi(arith_d.CmpIPredicate.eq, axis_val, const_val)

if_op = scf_d.IfOp(cond, results_=[])
with InsertionPoint(if_op.then_block):
store_fn()
scf_d.yield_([])


def _emit_atomic_guard(
emitter: ThreadEmitter,
store_fn: Callable[[], None],
) -> None:
"""Emit atomic test-and-set guard for first-writer-wins semantics.

Uses atomic_rmw to ensure only one thread (the first to arrive) executes store.
The flag must be pre-allocated in workgroup memory and initialized before kernel.
"""
i32 = IntegerType.get_signless(32)
idx_ty = IndexType.get()

# Workgroup-local flag (address space 3) - must be pre-initialized
flag_type = MemRefType.get([1], i32, memory_space=Attribute.parse("3"))
flag = memref_d.alloca(flag_type, [], [])

c0 = arith_d.constant(idx_ty, 0)
c1_i32 = arith_d.constant(i32, 1)
c0_i32 = arith_d.constant(i32, 0)

# Atomic add 1, returns old value - first thread gets 0
old_val = memref_d.atomic_rmw(arith_d.AtomicRMWKind.addi, c1_i32, flag, [c0])
is_first = arith_d.cmpi(arith_d.CmpIPredicate.eq, old_val, c0_i32)

if_op = scf_d.IfOp(is_first, results_=[])
with InsertionPoint(if_op.then_block):
store_fn()
scf_d.yield_([])

90 changes: 90 additions & 0 deletions tests/kernel/compiler/write_analysis_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import unittest
import sympy

from iree.turbine.kernel.lang import sym
from iree.turbine.kernel.lang.tkw_types import IndexMapping
from iree.turbine.kernel.compiler.write_analysis import (
analyze_write,
AnalysisOutcome,
AnalysisResult,
OwnerPredicate,
)

M = sym.M
N = sym.N
K = sym.K


class WriteAnalysisTest(unittest.TestCase):
def testNoneMappingIsUnique(self):
"""None mapping (identity) should be PROVEN_UNIQUE."""
result = analyze_write(None, (M, N))
self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE)

def testIdentityMappingIsUnique(self):
"""Identity IndexMapping should be PROVEN_UNIQUE."""
i0 = IndexMapping.iterator(0)
i1 = IndexMapping.iterator(1)
mapping = IndexMapping(2, {M: i0, N: i1}, {M: i0, N: i1})
result = analyze_write(mapping, (M, N))
self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE)

def testHasIdentityFlagIsUnique(self):
"""has_identity=True should bypass analysis and return PROVEN_UNIQUE."""
i0 = IndexMapping.iterator(0)
# Non-identity mapping (broadcast to constant)
mapping = IndexMapping(1, {M: i0}, {N: 0})
result = analyze_write(mapping, (N,), has_identity=True)
self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE)

def testBroadcastIsOwnerPredicate(self):
"""Broadcast (constant output) should be OWNER_PREDICATE."""
i0 = IndexMapping.iterator(0)
# All outputs are constants - broadcast pattern
mapping = IndexMapping(1, {M: i0}, {N: 0})
result = analyze_write(mapping, (N,))
self.assertEqual(result.outcome, AnalysisOutcome.OWNER_PREDICATE)
self.assertEqual(result.predicate.axis, 0)

def testFloorDivNeedsGuard(self):
"""Floor division pattern should be NEEDS_GUARD."""
i0 = IndexMapping.iterator(0)
# Floor division creates many-to-one mapping
mapping = IndexMapping(1, {M: i0}, {N: sympy.floor(i0 / 2)})
result = analyze_write(mapping, (N,))
self.assertEqual(result.outcome, AnalysisOutcome.NEEDS_GUARD)


class OwnerPredicateTest(unittest.TestCase):
def testFrozen(self):
"""OwnerPredicate should be immutable."""
pred = OwnerPredicate(axis=0, value=0)
with self.assertRaises(AttributeError):
pred.axis = 1


class AnalysisResultTest(unittest.TestCase):
def testUniqueFactory(self):
result = AnalysisResult.unique()
self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE)

def testOwnerFactory(self):
result = AnalysisResult.owner(axis=1, value=0)
self.assertEqual(result.outcome, AnalysisOutcome.OWNER_PREDICATE)
self.assertEqual(result.predicate.axis, 1)

def testGuardFactory(self):
result = AnalysisResult.guard()
self.assertEqual(result.outcome, AnalysisOutcome.NEEDS_GUARD)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
46 changes: 46 additions & 0 deletions tests/kernel/compiler/write_codegen_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Tests for guarded store dispatch logic.

Note: Full IR emission tests require ThreadEmitter context and are covered
by vector_codegen integration tests. These tests verify dispatch behavior.
"""

import logging
import unittest

from iree.turbine.kernel.compiler.write_analysis import AnalysisResult
from iree.turbine.kernel.compiler.write_codegen import emit_guarded_store
from iree.turbine.kernel.compiler import base


class EmitGuardedStoreTest(unittest.TestCase):
def setUp(self):
self._original = base.options.enable_single_writer_guards

def tearDown(self):
base.options.enable_single_writer_guards = self._original

def testProvenUniqueCallsStoreDirectly(self):
"""PROVEN_UNIQUE should call store function without guards."""
base.options.enable_single_writer_guards = True
called = []
emit_guarded_store(None, AnalysisResult.unique(), lambda: called.append(1))
self.assertEqual(len(called), 1)

def testGuardsDisabledAlwaysCallsStore(self):
"""When guards disabled, all outcomes call store directly."""
base.options.enable_single_writer_guards = False
called = []
emit_guarded_store(None, AnalysisResult.unique(), lambda: called.append(1))
emit_guarded_store(None, AnalysisResult.guard(), lambda: called.append(1))
self.assertEqual(len(called), 2)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()