Skip to content

Commit 219f9dd

Browse files
committed
Add missing operations and make pdl-to-lean work
1 parent 823ba74 commit 219f9dd

File tree

1 file changed

+66
-10
lines changed

1 file changed

+66
-10
lines changed

xdsl_smt/utils/pdl_to_lean.py

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from xdsl.dialects import pdl, smt
66
from xdsl.ir import SSAValue
7-
from xdsl_smt.dialects.smt_bitvector_dialect import AddOp
7+
from xdsl.context import Context
8+
from xdsl_smt.dialects import get_all_dialects
89

910

1011
def get_replace_terminator(pattern: pdl.PatternOp) -> pdl.ReplaceOp:
@@ -19,33 +20,61 @@ def get_replace_terminator(pattern: pdl.PatternOp) -> pdl.ReplaceOp:
1920
return replace
2021

2122

23+
unary_lean_name: dict[str, str] = {
24+
"smt.bv.neg": "-",
25+
"smt.bv.not": "BitVec.not",
26+
}
27+
2228
op_name_to_lean_infix: dict[str, str] = {
2329
"smt.eq": "=",
2430
"smt.distinct": "≠",
2531
"smt.bv.sub": "-",
2632
"smt.bv.mul": "*",
2733
"smt.bv.add": "+",
34+
"smt.bv.and": "&&&",
35+
"smt.bv.or": "|||",
36+
"smt.bv.xor": "^^^",
37+
"smt.bv.lshr": ">>>",
38+
"smt.bv.shl": "<<<",
39+
}
40+
41+
op_name_to_lean_prefix: dict[str, str] = {
42+
"smt.bv.ashr": "BitVec.sshiftRight'",
2843
}
2944

3045

3146
def get_lean_expression(value: SSAValue, operand_to_str: dict[SSAValue, str]) -> str:
3247
if value in operand_to_str:
3348
return operand_to_str[value]
49+
if isinstance((pdl_op := value.owner), pdl.ResultOp):
50+
return get_lean_expression(pdl_op.parent_, operand_to_str)
3451
if isinstance((pdl_op := value.owner), pdl.OperationOp):
52+
operands = [
53+
get_lean_expression(operand, operand_to_str)
54+
for operand in pdl_op.operand_values
55+
]
3556
assert pdl_op.opName is not None
3657
op_name = pdl_op.opName.data
37-
lean_infix = op_name_to_lean_infix[op_name]
38-
lhs = get_lean_expression(pdl_op.operand_values[0], operand_to_str)
39-
rhs = get_lean_expression(pdl_op.operand_values[1], operand_to_str)
40-
return f"({lhs} {lean_infix} {rhs})"
58+
if len(operands) == 1:
59+
return f"({unary_lean_name[op_name]} {operands[0]})"
60+
elif len(operands) == 2:
61+
lhs, rhs = operands
62+
if op_name in op_name_to_lean_infix.keys():
63+
lean_infix = op_name_to_lean_infix[op_name]
64+
return f"({lhs} {lean_infix} {rhs})"
65+
if op_name in op_name_to_lean_prefix.keys():
66+
lean_prefix = op_name_to_lean_prefix[op_name]
67+
return f"({lean_prefix} {lhs} {rhs})"
68+
assert False, f"Unsupported operation: {op_name}"
4169
assert False, f"Unsupported owner: {value.owner}"
4270

4371

4472
def get_lean_type(value: SSAValue) -> str:
45-
assert isinstance(value.owner, pdl.TypeOp)
73+
if not isinstance(value.owner, pdl.TypeOp):
74+
raise ValueError(f"Only pdl.type values are supported, got {value.owner}")
4675
assert (MlirType := value.owner.constantType) is not None
4776
assert isinstance(MlirType, smt.BitVectorType)
48-
return f"BitVec {MlirType.width}"
77+
return f"BitVec {MlirType.width.data}"
4978

5079

5180
def pdl_to_lean(pattern: pdl.PatternOp) -> str:
@@ -57,7 +86,8 @@ def pdl_to_lean(pattern: pdl.PatternOp) -> str:
5786
for op in pattern.walk():
5887
if isinstance(op, pdl.OperandOp):
5988
operand_to_str[op.value] = names[num_names]
60-
operand_to_type[op.value] = get_lean_type(op.value)
89+
assert op.value_type is not None
90+
operand_to_type[op.value] = get_lean_type(op.value_type)
6191
num_names += 1
6292

6393
replace_op = get_replace_terminator(pattern)
@@ -70,9 +100,35 @@ def pdl_to_lean(pattern: pdl.PatternOp) -> str:
70100
rhs_expr = get_lean_expression(replace_op.repl_values[0], operand_to_str)
71101

72102
# Convert the PDL pattern to a Lean theorem statement.
73-
assert pattern.sym_name is not None
103+
# assert pattern.sym_name is not None
74104
arguments_str = " ".join(
75105
f"({operand_to_str[value]} : {operand_to_type[value]})"
76106
for value in operand_to_str.keys()
77107
)
78-
return f"theorem {pattern.sym_name} {arguments_str}:\n {lhs_expr} = {rhs_expr} := by\n sorry"
108+
return f"example {arguments_str}:\n {lhs_expr} = {rhs_expr} := by\n bv_decide +acNf"
109+
110+
111+
if __name__ == "__main__":
112+
import sys
113+
114+
from xdsl.parser import Parser
115+
116+
if len(sys.argv) != 2:
117+
print("Usage: pdl_to_lean.py <pdl_file>")
118+
sys.exit(1)
119+
120+
pdl_file = sys.argv[1]
121+
with open(pdl_file, "r") as f:
122+
pdl_text = f.read()
123+
124+
ctx = Context()
125+
for name, factory in get_all_dialects().items():
126+
ctx.register_dialect(name, factory)
127+
128+
module = Parser(ctx, pdl_text).parse_module()
129+
130+
for pattern in module.walk():
131+
if isinstance(pattern, pdl.PatternOp):
132+
lean_theorem = pdl_to_lean(pattern)
133+
# aprint(pattern)
134+
print(lean_theorem)

0 commit comments

Comments
 (0)