Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions src/exo/rewrite/LoopIR_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..core.prelude import *
from .new_eff import Check_Aliasing
import exo.core.internal_cursors as ic
from .new_analysis_core import A, SMTSolver


def _get_smt_solver():
Expand Down Expand Up @@ -88,9 +89,42 @@ def DoReplace(subproc, block_cursor):
stmts = [c._node for c in block_cursor[:n_stmts]]
live_vars = Get_Live_Variables(block_cursor[0])
new_args = Unification(temp_subproc, stmts, live_vars).result()
const_args = {k: v for k, v in new_args.items() if isinstance(v, LoopIR.Const)}

slv = SMTSolver(verbose=False)
slv.push()

def lift_expr(e):
if isinstance(e, LoopIR.Read):
if e.name in const_args:
return lift_expr(const_args[e.name])
else:
return A.Var(e.name, e.type, e.srcinfo)

elif isinstance(e, LoopIR.Const):
return A.Const(e.val, e.type, e.srcinfo)

elif isinstance(e, LoopIR.BinOp):
lhs = lift_expr(e.lhs)
rhs = lift_expr(e.rhs)
if lhs and rhs:
return A.BinOp(e.op, lhs, rhs, e.type, e.srcinfo)

return None

for pred in temp_subproc.preds:
lifted_pred = lift_expr(pred)

if lifted_pred and not slv.satisfy(lifted_pred):
raise TypeError(
pred,
f"The assertion {lifted_pred} at {lifted_pred.srcinfo} is always unsatisfiable.",
)

slv.pop()

# but don't use a different LoopIR.proc for the callsite itself
new_call = LoopIR.Call(subproc, new_args, stmts[0].srcinfo)
new_call = LoopIR.Call(subproc, list(new_args.values()), stmts[0].srcinfo)

ir, fwd = block_cursor._replace([new_call])
Check_Aliasing(ir)
Expand Down Expand Up @@ -707,7 +741,7 @@ def get_arg(fa):
bufvar = self.buf_holes[fa.name]
return bufvar.get_solution(self, solutions, stmt_block[0].srcinfo)

self.new_args = [get_arg(fa) for fa in subproc.args]
self.new_args = dict([(fa.name, get_arg(fa)) for fa in subproc.args])

def err(self):
raise TypeError("subproc and pattern don't match")
Expand Down
Loading