Skip to content

Commit db1a85b

Browse files
committed
Fix pdl-to-smt
1 parent 87a1ef3 commit db1a85b

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

xdsl_smt/passes/pdl_to_smt.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass, field
22
from typing import Callable, ClassVar, Sequence
3-
from xdsl.dialects.builtin import ModuleOp, IntegerType, StringAttr, UnitAttr
3+
from xdsl.dialects.builtin import ModuleOp, IntegerType, StringAttr, UnitAttr, ArrayAttr
44
from xdsl.dialects.pdl import (
55
ApplyNativeConstraintOp,
66
ApplyNativeRewriteOp,
@@ -290,8 +290,12 @@ def match_and_rewrite(self, op: ReplaceOp, rewriter: PatternRewriter):
290290
replacing_value = replacing_values[0]
291291

292292
if self.refinement is None:
293+
replaced_types = op.attributes["__replaced_types"]
294+
replacing_types = op.attributes["__replacing_types"]
295+
assert isa(replaced_types, ArrayAttr[Attribute])
296+
assert isa(replacing_types, ArrayAttr[Attribute])
293297
refinement = find_refinement_semantics(
294-
replaced_value.type, replacing_value.type
298+
replaced_types.data[0], replacing_types.data[0]
295299
)
296300
else:
297301
refinement = self.refinement
@@ -529,6 +533,46 @@ def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
529533
), "Operations used as computations in PDL should not have effects"
530534

531535

536+
def annotate_replace_op(pattern: PatternOp):
537+
"""
538+
Annotate all `pdl.replace` operations in the given pattern with the types of the
539+
replaced and replacing values.
540+
"""
541+
values_to_types: dict[SSAValue, tuple[Attribute, ...]] = {}
542+
for op in pattern.walk():
543+
if isinstance(op, OperandOp):
544+
if op.value_type is None:
545+
raise Exception("Cannot handle non-typed operands")
546+
owner = op.value_type.owner
547+
assert isinstance(owner, TypeOp)
548+
assert owner.constantType is not None
549+
values_to_types[op.value] = (owner.constantType,)
550+
if isinstance(op, OperationOp):
551+
values: list[Attribute] = []
552+
for type_value in op.type_values:
553+
owner = type_value.owner
554+
assert isinstance(owner, TypeOp)
555+
assert owner.constantType is not None
556+
values.append(owner.constantType)
557+
values_to_types[op.op] = tuple(values)
558+
if isinstance(op, ResultOp):
559+
values_to_types[op.val] = (
560+
values_to_types[op.parent_][op.index.value.data],
561+
)
562+
if isinstance(op, ReplaceOp):
563+
replaced_value_types = values_to_types[op.op_value]
564+
if len(op.repl_values) != 0:
565+
replacing_value_types: list[Attribute] = []
566+
for repl_value in op.repl_values:
567+
replacing_value_types.append(values_to_types[repl_value][0])
568+
replacing_types = tuple(replacing_value_types)
569+
else:
570+
assert op.repl_operation is not None
571+
replacing_types = values_to_types[op.repl_operation]
572+
op.attributes["__replaced_types"] = ArrayAttr(replaced_value_types)
573+
op.attributes["__replacing_types"] = ArrayAttr(replacing_types)
574+
575+
532576
@dataclass
533577
class PDLToSMTLowerer:
534578
native_rewrites: dict[
@@ -621,6 +665,7 @@ def lower_to_smt(self, module: ModuleOp, ctx: Context) -> None:
621665
new_state_op = DeclareConstOp(StateType())
622666
Rewriter.insert_op(new_state_op, insert_point)
623667
rewrite_context = PDLToSMTRewriteContext(new_state_op.res, new_state_op.res)
668+
annotate_replace_op(pattern)
624669

625670
walker = PatternRewriteWalker(
626671
GreedyRewritePatternApplier(

0 commit comments

Comments
 (0)