Skip to content

Commit cfd9912

Browse files
HatsunespicaStefanPiscu
authored andcommitted
Add lowerings from smt tensor to low-level smt (opencompl#82)
1 parent c2a6aa0 commit cfd9912

File tree

7 files changed

+177
-5
lines changed

7 files changed

+177
-5
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: xdsl-smt "%s" -p=lower-smt-tensor -t=smt | filecheck "%s"
2+
3+
// Lower tensor from "smt.tensor" operations.
4+
5+
builtin.module {
6+
%0 = "smt.declare_const"() : () -> !smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>
7+
// CHECK: (declare-const $tmp (Array (_ BitVec 64) (Array (_ BitVec 64) (_ FloatingPoint 8 24))
8+
9+
%idx1 = "smt.declare_const"() : () -> !smt.bv<64>
10+
%idx2 = "smt.declare_const"() : () -> !smt.bv<64>
11+
// CHECK-NEXT: (declare-const $idx1 (_ BitVec 64))
12+
// CHECK-NEXT: (declare-const $idx2 (_ BitVec 64))
13+
14+
%extract = "smt.tensor.extract"(%0, %idx1, %idx2):
15+
(!smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>, !smt.bv<64>, !smt.bv<64>) -> !smt.fp<8, 24>
16+
%zero = "smt.fp.pzero"() : () -> !smt.fp<8,24>
17+
%eq_zero = "smt.eq"(%extract, %zero) : (!smt.fp<8,24>, !smt.fp<8,24>) -> !smt.bool
18+
"smt.assert"(%eq_zero) : (!smt.bool) -> ()
19+
// CHECK-NEXT: (assert (= (select (select $tmp $idx1) $idx2) (_ +zero 8 24)))
20+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: xdsl-smt "%s" -p=rewrite-smt-tensor | filecheck "%s"
2+
3+
// Rewrite tensor from "smt.extract" operations.
4+
5+
builtin.module {
6+
%0 = "smt.declare_const"() : () -> !smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>
7+
%idx1 = "smt.declare_const"() : () -> !smt.bv<64>
8+
%idx2 = "smt.declare_const"() : () -> !smt.bv<64>
9+
%transpose = "smt.tensor.transpose"(%0){permutation = [1, 0]} :
10+
(!smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>) -> !smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>
11+
%extract = "smt.tensor.extract"(%transpose, %idx1, %idx2):
12+
(!smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>, !smt.bv<64>, !smt.bv<64>) -> !smt.fp<8, 24>
13+
}
14+
15+
// CHECK: %0 = "smt.declare_const"() : () -> !smt.tensor.tensor<[3 : i64, 3 : i64], !smt.fp<8, 24>, none>
16+
// CHECK-NEXT: %idx1 = "smt.declare_const"() : () -> !smt.bv<64>
17+
// CHECK-NEXT: %idx2 = "smt.declare_const"() : () -> !smt.bv<64>
18+
// CHECK-NEXT: %extract = "smt.tensor.extract"(%0, %idx2, %idx1) : (!smt.tensor.tensor<[3 : i64, 3 : i64], !smt.fp<8, 24>, none>, !smt.bv<64>, !smt.bv<64>) -> !smt.fp<8, 24>

xdsl_smt/cli/xdsl_smt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from xdsl_smt.passes.load_parametric_int_semantics import LoadIntSemanticsPass
2525
from xdsl_smt.passes.lower_memory_effects import LowerMemoryEffectsPass
2626
from xdsl_smt.passes.lower_effects_with_memory import LowerEffectsWithMemoryPass
27+
from xdsl_smt.passes.lower_smt_tensor import LowerSMTTensor
2728
from xdsl_smt.passes.merge_func_results import MergeFuncResultsPass
2829
from xdsl_smt.passes.lower_memory_to_array import LowerMemoryToArrayPass
2930
from xdsl_smt.passes.raise_llvm_to_func import RaiseLLVMToFunc
@@ -48,6 +49,7 @@
4849
from xdsl_smt.passes.lower_pairs import LowerPairs
4950
from xdsl_smt.passes.lower_to_smt import LowerToSMTPass
5051
from xdsl_smt.passes.lower_ub_to_pairs import LowerUBToPairs
52+
from xdsl_smt.passes.rewrite_smt_tensor import RewriteSMTTensor
5153
from xdsl_smt.passes.smt_expand import SMTExpand
5254
from xdsl_smt.passes.pdl_add_implicit_properties import PDLAddImplicitPropertiesPass
5355

@@ -109,6 +111,7 @@ def register_all_dialects(self):
109111
self.ctx.load_registered_dialect(SMTUtilsDialect.name)
110112
self.ctx.load_registered_dialect(SMTArray.name)
111113
self.ctx.load_registered_dialect(SMTFloatingPointDialect.name)
114+
self.ctx.load_registered_dialect(SMTTensorDialect.name)
112115

113116
def register_all_passes(self):
114117
super().register_all_passes()
@@ -132,6 +135,8 @@ def register_all_passes(self):
132135
)
133136
self.register_pass(RaiseLLVMToFunc.name, lambda: RaiseLLVMToFunc)
134137
self.register_pass(LowerAbbvToBvPass.name, lambda: LowerAbbvToBvPass)
138+
self.register_pass(RewriteSMTTensor.name, lambda: RewriteSMTTensor)
139+
self.register_pass(LowerSMTTensor.name, lambda: LowerSMTTensor)
135140

136141
def register_all_targets(self):
137142
super().register_all_targets()

xdsl_smt/dialects/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ def get_synth_dialect():
8383

8484
return SynthDialect
8585

86+
def get_tensor_dialect():
87+
from xdsl_smt.dialects.smt_tensor_dialect import SMTTensorDialect
88+
89+
return SMTTensorDialect
90+
8691
all_dialects["abbv"] = get_abbv_dialect
8792
all_dialects["pdl"] = get_pdl_dialect
8893
all_dialects["smt"] = get_smt_dialect
@@ -94,5 +99,6 @@ def get_synth_dialect():
9499
all_dialects["tv"] = get_tv_dialect
95100
all_dialects["hoare"] = get_hoare_dialect
96101
all_dialects["synth"] = get_synth_dialect
102+
all_dialects["tensor"] = get_tensor_dialect
97103

98104
return all_dialects

xdsl_smt/dialects/smt_tensor_dialect.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
from xdsl_smt.dialects.smt_bitvector_dialect import BitVectorType
3434

3535

36+
INDEX_WIDTH = 64
37+
IndexType = BitVectorType(INDEX_WIDTH)
38+
39+
3640
@irdl_attr_definition
3741
class SMTTensorType(
3842
Generic[AttributeCovT],
@@ -41,28 +45,33 @@ class SMTTensorType(
4145
ShapedType,
4246
ContainerType[AttributeCovT],
4347
):
44-
name = "smt.tensor"
48+
name = "smt.tensor.tensor"
4549

46-
shape: ArrayAttr[IntAttr]
50+
shape: ArrayAttr[IntegerAttr]
4751
element_type: AttributeCovT
4852
encoding: Attribute
4953

5054
def __init__(
5155
self,
5256
element_type: AttributeCovT,
53-
shape: Iterable[int] | Iterable[IntAttr],
57+
shape: Iterable[int] | Iterable[IntegerAttr],
5458
encoding: Attribute = NoneAttr(),
5559
):
5660
shape = ArrayAttr(
57-
[IntAttr(dim) if isinstance(dim, int) else dim for dim in shape]
61+
[
62+
IntegerAttr.from_int_and_width(dim, INDEX_WIDTH)
63+
if isinstance(dim, int)
64+
else dim
65+
for dim in shape
66+
]
5867
)
5968
super().__init__(shape, element_type, encoding)
6069

6170
def get_num_dims(self) -> int:
6271
return len(self.shape.data)
6372

6473
def get_shape(self) -> tuple[int, ...]:
65-
return tuple(i.data for i in self.shape.data)
74+
return tuple(i.value.data for i in self.shape.data)
6675

6776
def get_element_type(self) -> AttributeCovT:
6877
return self.element_type
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from xdsl_smt.dialects import smt_array_dialect as smt_array
2+
3+
from xdsl_smt.dialects.smt_dialect import (
4+
DeclareConstOp,
5+
)
6+
from xdsl_smt.dialects.smt_tensor_dialect import (
7+
IndexType,
8+
SMTTensorType,
9+
TensorExtractOp,
10+
)
11+
from xdsl.dialects.builtin import ModuleOp
12+
from xdsl.ir import Attribute
13+
from xdsl.utils.hints import isa
14+
from xdsl.context import Context
15+
from xdsl.pattern_rewriter import (
16+
GreedyRewritePatternApplier,
17+
PatternRewriteWalker,
18+
PatternRewriter,
19+
RewritePattern,
20+
op_type_rewrite_pattern,
21+
)
22+
from xdsl.passes import ModulePass
23+
24+
25+
def lower_tensor_type(typ: Attribute) -> Attribute:
26+
if isa(typ, SMTTensorType):
27+
result = typ.element_type
28+
index_type = IndexType
29+
for _ in typ.shape:
30+
result = smt_array.ArrayType(index_type, result)
31+
return result
32+
return typ
33+
34+
35+
class DeclareConstOpPattern(RewritePattern):
36+
@op_type_rewrite_pattern
37+
def match_and_rewrite(self, op: DeclareConstOp, rewriter: PatternRewriter):
38+
if isa(op.res.type, SMTTensorType):
39+
new_constant_op = DeclareConstOp(lower_tensor_type(op.res.type))
40+
rewriter.replace_matched_op(new_constant_op)
41+
42+
43+
class TensorExtractOpPattern(RewritePattern):
44+
@op_type_rewrite_pattern
45+
def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter):
46+
source = op.tensor
47+
assert isinstance(source.type, smt_array.ArrayType)
48+
select_ops: list[smt_array.SelectOp] = []
49+
for idx in op.indices:
50+
select_ops.append(smt_array.SelectOp(source, idx))
51+
source = select_ops[-1].res
52+
rewriter.replace_matched_op(select_ops)
53+
54+
55+
class LowerSMTTensor(ModulePass):
56+
name = "lower-smt-tensor"
57+
58+
def apply(self, ctx: Context, op: ModuleOp):
59+
walker = PatternRewriteWalker(
60+
GreedyRewritePatternApplier(
61+
[DeclareConstOpPattern(), TensorExtractOpPattern()]
62+
)
63+
)
64+
walker.rewrite_module(op)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from xdsl.transforms.common_subexpression_elimination import (
2+
CommonSubexpressionElimination,
3+
)
4+
5+
from xdsl.ir import SSAValue
6+
from xdsl_smt.dialects.smt_tensor_dialect import (
7+
TensorTransposeOp,
8+
TensorExtractOp,
9+
)
10+
from xdsl.dialects.builtin import ModuleOp
11+
from xdsl.context import Context
12+
from xdsl.pattern_rewriter import (
13+
GreedyRewritePatternApplier,
14+
PatternRewriteWalker,
15+
PatternRewriter,
16+
RewritePattern,
17+
op_type_rewrite_pattern,
18+
)
19+
from xdsl.passes import ModulePass
20+
21+
22+
class RewriteTransposeOpPattern(RewritePattern):
23+
@op_type_rewrite_pattern
24+
def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter):
25+
for use in op.result.uses:
26+
extract_op = use.operation
27+
if isinstance(extract_op, TensorExtractOp):
28+
permutations = op.get_permutation()
29+
new_indices: list[SSAValue] = []
30+
for i in permutations:
31+
new_indices.append(extract_op.indices[i])
32+
new_extract_op = TensorExtractOp(op.operand, new_indices)
33+
rewriter.replace_op(extract_op, new_extract_op)
34+
if op.result.uses.get_length() == 0:
35+
rewriter.erase_matched_op()
36+
37+
38+
class RewriteSMTTensor(ModulePass):
39+
"""
40+
Rewrite patterns like `extract(op(arg))` to `extract(arg')`
41+
"""
42+
43+
name = "rewrite-smt-tensor"
44+
45+
def apply(self, ctx: Context, op: ModuleOp):
46+
walker = PatternRewriteWalker(
47+
GreedyRewritePatternApplier([RewriteTransposeOpPattern()])
48+
)
49+
walker.rewrite_module(op)
50+
CommonSubexpressionElimination().apply(ctx, op)

0 commit comments

Comments
 (0)