-
Notifications
You must be signed in to change notification settings - Fork 82
Expand file tree
/
Copy pathwrite_analysis.py
More file actions
89 lines (66 loc) · 2.84 KB
/
Copy pathwrite_analysis.py
File metadata and controls
89 lines (66 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Analysis for single-writer memory operations."""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, TYPE_CHECKING
import sympy
if TYPE_CHECKING:
from ..lang.tkw_types import IndexMapping
from .._support.indexing import IndexExpr
class AnalysisOutcome(Enum):
    """Verdict reached when checking a write for the single-writer property."""

    # Every thread provably writes a distinct location; no extra check needed.
    PROVEN_UNIQUE = auto()
    # Only the owner described by an ``OwnerPredicate`` should perform the write.
    OWNER_PREDICATE = auto()
    # Uniqueness could not be established; the write needs a full guard.
    NEEDS_GUARD = auto()
@dataclass(frozen=True, slots=True)
class OwnerPredicate:
"""Structured owner predicate: grid[axis] == value."""
axis: int
value: int = 0
@dataclass(slots=True)
class AnalysisResult:
    """Outcome of a single-writer analysis plus its optional owner predicate.

    Prefer the factory classmethods below over direct construction: only
    :meth:`owner` attaches a predicate; :meth:`unique` and :meth:`guard`
    leave ``predicate`` as ``None``.
    """

    # Verdict of the analysis.
    outcome: AnalysisOutcome
    # Populated by the factories only for ``OWNER_PREDICATE`` outcomes.
    predicate: Optional[OwnerPredicate] = None

    @classmethod
    def unique(cls) -> AnalysisResult:
        """Return a ``PROVEN_UNIQUE`` result with no predicate."""
        return cls(outcome=AnalysisOutcome.PROVEN_UNIQUE)

    @classmethod
    def owner(cls, axis: int, value: int = 0) -> AnalysisResult:
        """Return an ``OWNER_PREDICATE`` result for ``grid[axis] == value``."""
        owner_check = OwnerPredicate(axis=axis, value=value)
        return cls(outcome=AnalysisOutcome.OWNER_PREDICATE, predicate=owner_check)

    @classmethod
    def guard(cls) -> AnalysisResult:
        """Return a ``NEEDS_GUARD`` result with no predicate."""
        return cls(outcome=AnalysisOutcome.NEEDS_GUARD)
def _is_constant(expr) -> bool:
    """Report whether *expr* is a numeric constant.

    Python ``int`` values and ``sympy.Integer`` instances qualify outright.
    Anything else is judged by its ``is_number`` attribute (as exposed by
    SymPy objects); objects without that attribute are not constants.
    """
    if isinstance(expr, (int, sympy.Integer)):
        return True
    return getattr(expr, 'is_number', False)
def analyze_write(
    mapping: Optional[IndexMapping],
    ref_shape: tuple[IndexExpr, ...],
    has_identity: bool = False,
) -> AnalysisResult:
    """Classify how a write can be guaranteed to have a single writer.

    Args:
        mapping: Index mapping applied by the write, or ``None``. A ``None``
            mapping is treated the same as an identity mapping.
        ref_shape: Presumably the shape of the written reference; it is not
            read by this function.
        has_identity: Caller-supplied hint that the write already uses an
            identity mapping, which skips further inspection.

    Returns:
        ``PROVEN_UNIQUE`` for absent or identity mappings, ``OWNER_PREDICATE``
        when an owner predicate can be derived from the mapping, and
        ``NEEDS_GUARD`` otherwise.
    """
    # Identity (or no) mapping means each thread writes its own location.
    # ``mapping.is_identity()`` is only consulted when neither shortcut applies.
    trivially_unique = has_identity or mapping is None
    if trivially_unique or mapping.is_identity():
        return AnalysisResult.unique()

    # Non-identity mapping: try to find a single owner, else fall back to a guard.
    predicate = _extract_owner_predicate(mapping)
    if not predicate:
        return AnalysisResult.guard()
    return AnalysisResult.owner(predicate.axis, predicate.value)
def _extract_owner_predicate(mapping: IndexMapping) -> Optional[OwnerPredicate]:
    """Derive a canonical owner predicate from a non-identity write mapping.

    Only one pattern is currently recognised. When every output index
    expression is a constant (a broadcast/reduction store), the written
    location does not vary across threads. The write is then assigned to the
    canonical owner ``grid[0] == 0``.

    Args:
        mapping: Index mapping of the write. Its ``output_mapping`` values are
            the per-dimension output index expressions.

    Returns:
        The owner predicate for a recognised pattern. Otherwise ``None``, which
        tells the caller to fall back to a full guard.
    """
    output_exprs = mapping.output_mapping.values()

    # Pattern: all output dimensions are constants (broadcast/reduction).
    # Note that ``all`` is vacuously true for an empty output mapping.
    if all(_is_constant(expr) for expr in output_exprs):
        return OwnerPredicate(axis=0, value=0)

    # Every other mapping has no owner we can derive here and needs a full
    # guard. That includes floor-division index expressions, which create a
    # many-to-one mapping and deliberately fall through to this return.
    return None