Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/filecheck/dialects/tensor-theory/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: xdsl-smt "%s" -p=lower-smt-tensor -t=smt | filecheck "%s"

// Lower tensor from "smt.tensor" operations.

builtin.module {
%0 = "smt.declare_const"() : () -> !smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>
// CHECK: (declare-const $tmp (Array (_ BitVec 64) (Array (_ BitVec 64) (_ FloatingPoint 8 24))

%idx1 = "smt.declare_const"() : () -> !smt.bv<64>
%idx2 = "smt.declare_const"() : () -> !smt.bv<64>
// CHECK-NEXT: (declare-const $idx1 (_ BitVec 64))
// CHECK-NEXT: (declare-const $idx2 (_ BitVec 64))

%extract = "smt.tensor.extract"(%0, %idx1, %idx2):
(!smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>, !smt.bv<64>, !smt.bv<64>) -> !smt.fp<8, 24>
%zero = "smt.fp.pzero"() : () -> !smt.fp<8,24>
%eq_zero = "smt.eq"(%extract, %zero) : (!smt.fp<8,24>, !smt.fp<8,24>) -> !smt.bool
"smt.assert"(%eq_zero) : (!smt.bool) -> ()
// CHECK-NEXT: (assert (= (select (select $tmp $idx1) $idx2) (_ +zero 8 24)))
}
18 changes: 18 additions & 0 deletions tests/filecheck/rewrite-tensors/rewrite-smt-tensor.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: xdsl-smt "%s" -p=rewrite-smt-tensor | filecheck "%s"

// Rewrite tensor from "smt.extract" operations.

builtin.module {
%0 = "smt.declare_const"() : () -> !smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>
%idx1 = "smt.declare_const"() : () -> !smt.bv<64>
%idx2 = "smt.declare_const"() : () -> !smt.bv<64>
%transpose = "smt.tensor.transpose"(%0){permutation = [1, 0]} :
(!smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>) -> !smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>
%extract = "smt.tensor.extract"(%transpose, %idx1, %idx2):
(!smt.tensor.tensor<[3, 3], !smt.fp<8, 24>, none>, !smt.bv<64>, !smt.bv<64>) -> !smt.fp<8, 24>
}

// CHECK: %0 = "smt.declare_const"() : () -> !smt.tensor.tensor<[3 : i64, 3 : i64], !smt.fp<8, 24>, none>
// CHECK-NEXT: %idx1 = "smt.declare_const"() : () -> !smt.bv<64>
// CHECK-NEXT: %idx2 = "smt.declare_const"() : () -> !smt.bv<64>
// CHECK-NEXT: %extract = "smt.tensor.extract"(%0, %idx2, %idx1) : (!smt.tensor.tensor<[3 : i64, 3 : i64], !smt.fp<8, 24>, none>, !smt.bv<64>, !smt.bv<64>) -> !smt.fp<8, 24>
5 changes: 5 additions & 0 deletions xdsl_smt/cli/xdsl_smt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from xdsl_smt.passes.load_parametric_int_semantics import LoadIntSemanticsPass
from xdsl_smt.passes.lower_memory_effects import LowerMemoryEffectsPass
from xdsl_smt.passes.lower_effects_with_memory import LowerEffectsWithMemoryPass
from xdsl_smt.passes.lower_smt_tensor import LowerSMTTensor
from xdsl_smt.passes.merge_func_results import MergeFuncResultsPass
from xdsl_smt.passes.lower_memory_to_array import LowerMemoryToArrayPass
from xdsl_smt.passes.raise_llvm_to_func import RaiseLLVMToFunc
Expand All @@ -47,6 +48,7 @@
from xdsl_smt.passes.lower_pairs import LowerPairs
from xdsl_smt.passes.lower_to_smt import LowerToSMTPass
from xdsl_smt.passes.lower_ub_to_pairs import LowerUBToPairs
from xdsl_smt.passes.rewrite_smt_tensor import RewriteSMTTensor
from xdsl_smt.passes.smt_expand import SMTExpand
from xdsl_smt.passes.pdl_add_implicit_properties import PDLAddImplicitPropertiesPass

Expand Down Expand Up @@ -107,6 +109,7 @@ def register_all_dialects(self):
self.ctx.load_registered_dialect(SMTUtilsDialect.name)
self.ctx.load_registered_dialect(SMTArray.name)
self.ctx.load_registered_dialect(SMTFloatingPointDialect.name)
self.ctx.load_registered_dialect(SMTTensorDialect.name)

def register_all_passes(self):
super().register_all_passes()
Expand All @@ -130,6 +133,8 @@ def register_all_passes(self):
)
self.register_pass(RaiseLLVMToFunc.name, lambda: RaiseLLVMToFunc)
self.register_pass(LowerAbbvToBvPass.name, lambda: LowerAbbvToBvPass)
self.register_pass(RewriteSMTTensor.name, lambda: RewriteSMTTensor)
self.register_pass(LowerSMTTensor.name, lambda: LowerSMTTensor)

def register_all_targets(self):
super().register_all_targets()
Expand Down
6 changes: 6 additions & 0 deletions xdsl_smt/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def get_synth_dialect():

return SynthDialect

def get_tensor_dialect():
from xdsl_smt.dialects.smt_tensor_dialect import SMTTensorDialect

return SMTTensorDialect

all_dialects["abbv"] = get_abbv_dialect
all_dialects["pdl"] = get_pdl_dialect
all_dialects["smt"] = get_smt_dialect
Expand All @@ -94,5 +99,6 @@ def get_synth_dialect():
all_dialects["tv"] = get_tv_dialect
all_dialects["hoare"] = get_hoare_dialect
all_dialects["synth"] = get_synth_dialect
all_dialects["tensor"] = get_tensor_dialect

return all_dialects
19 changes: 14 additions & 5 deletions xdsl_smt/dialects/smt_tensor_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
from xdsl_smt.dialects.smt_bitvector_dialect import BitVectorType


INDEX_WIDTH = 64
IndexType = BitVectorType(INDEX_WIDTH)


@irdl_attr_definition
class SMTTensorType(
Generic[AttributeCovT],
Expand All @@ -41,28 +45,33 @@ class SMTTensorType(
ShapedType,
ContainerType[AttributeCovT],
):
name = "smt.tensor"
name = "smt.tensor.tensor"

shape: ArrayAttr[IntAttr]
shape: ArrayAttr[IntegerAttr]
element_type: AttributeCovT
encoding: Attribute

def __init__(
self,
element_type: AttributeCovT,
shape: Iterable[int] | Iterable[IntAttr],
shape: Iterable[int] | Iterable[IntegerAttr],
encoding: Attribute = NoneAttr(),
):
shape = ArrayAttr(
[IntAttr(dim) if isinstance(dim, int) else dim for dim in shape]
[
IntegerAttr.from_int_and_width(dim, INDEX_WIDTH)
if isinstance(dim, int)
else dim
for dim in shape
]
)
super().__init__(shape, element_type, encoding)

def get_num_dims(self) -> int:
return len(self.shape.data)

def get_shape(self) -> tuple[int, ...]:
return tuple(i.data for i in self.shape.data)
return tuple(i.value.data for i in self.shape.data)

def get_element_type(self) -> AttributeCovT:
return self.element_type
Expand Down
64 changes: 64 additions & 0 deletions xdsl_smt/passes/lower_smt_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from xdsl_smt.dialects import smt_array_dialect as smt_array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason this file doesn't use the SMTLowerer and use the OperationSemantics class?


from xdsl_smt.dialects.smt_dialect import (
DeclareConstOp,
)
from xdsl_smt.dialects.smt_tensor_dialect import (
IndexType,
SMTTensorType,
TensorExtractOp,
)
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import Attribute
from xdsl.utils.hints import isa
from xdsl.context import Context
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.passes import ModulePass


def lower_tensor_type(typ: Attribute) -> Attribute:
if isa(typ, SMTTensorType):
result = typ.element_type
index_type = IndexType
for _ in typ.shape:
result = smt_array.ArrayType(index_type, result)
return result
return typ


class DeclareConstOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: DeclareConstOp, rewriter: PatternRewriter):
if isa(op.res.type, SMTTensorType):
new_constant_op = DeclareConstOp(lower_tensor_type(op.res.type))
rewriter.replace_matched_op(new_constant_op)


class TensorExtractOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter):
source = op.tensor
assert isinstance(source.type, smt_array.ArrayType)
select_ops: list[smt_array.SelectOp] = []
for idx in op.indices:
select_ops.append(smt_array.SelectOp(source, idx))
source = select_ops[-1].res
rewriter.replace_matched_op(select_ops)


class LowerSMTTensor(ModulePass):
name = "lower-smt-tensor"

def apply(self, ctx: Context, op: ModuleOp):
walker = PatternRewriteWalker(
GreedyRewritePatternApplier(
[DeclareConstOpPattern(), TensorExtractOpPattern()]
)
)
walker.rewrite_module(op)
50 changes: 50 additions & 0 deletions xdsl_smt/passes/rewrite_smt_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from xdsl.transforms.common_subexpression_elimination import (
CommonSubexpressionElimination,
)

from xdsl.ir import SSAValue
from xdsl_smt.dialects.smt_tensor_dialect import (
TensorTransposeOp,
TensorExtractOp,
)
from xdsl.dialects.builtin import ModuleOp
from xdsl.context import Context
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.passes import ModulePass


class RewriteTransposeOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter):
for use in op.result.uses:
extract_op = use.operation
if isinstance(extract_op, TensorExtractOp):
permutations = op.get_permutation()
new_indices: list[SSAValue] = []
for i in permutations:
new_indices.append(extract_op.indices[i])
new_extract_op = TensorExtractOp(op.operand, new_indices)
rewriter.replace_op(extract_op, new_extract_op)
if op.result.uses.get_length() == 0:
rewriter.erase_matched_op()


class RewriteSMTTensor(ModulePass):
"""
Rewrite patterns like `extract(op(arg))` to `extract(arg')`
"""

name = "rewrite-smt-tensor"

def apply(self, ctx: Context, op: ModuleOp):
walker = PatternRewriteWalker(
GreedyRewritePatternApplier([RewriteTransposeOpPattern()])
)
walker.rewrite_module(op)
CommonSubexpressionElimination().apply(ctx, op)