|
| 1 | +# Copyright 2024 The IREE Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +# See https://llvm.org/LICENSE.txt for license information. |
| 5 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +import logging |
| 8 | +import unittest |
| 9 | +import sympy |
| 10 | + |
| 11 | +from iree.turbine.kernel.lang import sym |
| 12 | +from iree.turbine.kernel.lang.tkw_types import IndexMapping |
| 13 | +from iree.turbine.kernel.compiler.write_analysis import ( |
| 14 | + analyze_write, |
| 15 | + AnalysisOutcome, |
| 16 | + AnalysisResult, |
| 17 | + OwnerPredicate, |
| 18 | +) |
| 19 | + |
| 20 | +M = sym.M |
| 21 | +N = sym.N |
| 22 | +K = sym.K |
| 23 | + |
| 24 | + |
| 25 | +class WriteAnalysisTest(unittest.TestCase): |
| 26 | + def testNoneMappingIsUnique(self): |
| 27 | + """None mapping (identity) should be PROVEN_UNIQUE.""" |
| 28 | + result = analyze_write(None, (M, N)) |
| 29 | + self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE) |
| 30 | + |
| 31 | + def testIdentityMappingIsUnique(self): |
| 32 | + """Identity IndexMapping should be PROVEN_UNIQUE.""" |
| 33 | + i0 = IndexMapping.iterator(0) |
| 34 | + i1 = IndexMapping.iterator(1) |
| 35 | + mapping = IndexMapping(2, {M: i0, N: i1}, {M: i0, N: i1}) |
| 36 | + result = analyze_write(mapping, (M, N)) |
| 37 | + self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE) |
| 38 | + |
| 39 | + def testHasIdentityFlagIsUnique(self): |
| 40 | + """has_identity=True should bypass analysis and return PROVEN_UNIQUE.""" |
| 41 | + i0 = IndexMapping.iterator(0) |
| 42 | + # Non-identity mapping (broadcast to constant) |
| 43 | + mapping = IndexMapping(1, {M: i0}, {N: 0}) |
| 44 | + result = analyze_write(mapping, (N,), has_identity=True) |
| 45 | + self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE) |
| 46 | + |
| 47 | + def testBroadcastIsOwnerPredicate(self): |
| 48 | + """Broadcast (constant output) should be OWNER_PREDICATE.""" |
| 49 | + i0 = IndexMapping.iterator(0) |
| 50 | + # All outputs are constants - broadcast pattern |
| 51 | + mapping = IndexMapping(1, {M: i0}, {N: 0}) |
| 52 | + result = analyze_write(mapping, (N,)) |
| 53 | + self.assertEqual(result.outcome, AnalysisOutcome.OWNER_PREDICATE) |
| 54 | + self.assertEqual(result.predicate.axis, 0) |
| 55 | + |
| 56 | + def testFloorDivNeedsGuard(self): |
| 57 | + """Floor division pattern should be NEEDS_GUARD.""" |
| 58 | + i0 = IndexMapping.iterator(0) |
| 59 | + # Floor division creates many-to-one mapping |
| 60 | + mapping = IndexMapping(1, {M: i0}, {N: sympy.floor(i0 / 2)}) |
| 61 | + result = analyze_write(mapping, (N,)) |
| 62 | + self.assertEqual(result.outcome, AnalysisOutcome.NEEDS_GUARD) |
| 63 | + |
| 64 | + |
| 65 | +class OwnerPredicateTest(unittest.TestCase): |
| 66 | + def testFrozen(self): |
| 67 | + """OwnerPredicate should be immutable.""" |
| 68 | + pred = OwnerPredicate(axis=0, value=0) |
| 69 | + with self.assertRaises(AttributeError): |
| 70 | + pred.axis = 1 |
| 71 | + |
| 72 | + |
| 73 | +class AnalysisResultTest(unittest.TestCase): |
| 74 | + def testUniqueFactory(self): |
| 75 | + result = AnalysisResult.unique() |
| 76 | + self.assertEqual(result.outcome, AnalysisOutcome.PROVEN_UNIQUE) |
| 77 | + |
| 78 | + def testOwnerFactory(self): |
| 79 | + result = AnalysisResult.owner(axis=1, value=0) |
| 80 | + self.assertEqual(result.outcome, AnalysisOutcome.OWNER_PREDICATE) |
| 81 | + self.assertEqual(result.predicate.axis, 1) |
| 82 | + |
| 83 | + def testGuardFactory(self): |
| 84 | + result = AnalysisResult.guard() |
| 85 | + self.assertEqual(result.outcome, AnalysisOutcome.NEEDS_GUARD) |
| 86 | + |
| 87 | + |
| 88 | +if __name__ == "__main__": |
| 89 | + logging.basicConfig(level=logging.DEBUG) |
| 90 | + unittest.main() |
0 commit comments