Skip to content

Fix unification #783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 25 additions & 31 deletions src/exo/frontend/boundscheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,9 @@ def __init__(self, proc):
for arg in proc.args:
if arg.type.is_numeric():
shape = [lift_expr(s) for s in arg.type.shape()]
# check that all sizes/indices are positive
# check that all sizes/indices are non negative
for s in shape:
self.check_pos_size(s)
self.check_non_negative(s)
# check the bounds
self.check_bounds(arg.name, shape, body_eff)

Expand Down Expand Up @@ -848,16 +848,6 @@ def check_bounds(self, sym, shape, eff):
for e in es:
self.check_in_bounds(sym, shape, e, y)

def check_pos_size(self, expr):
e_pos = SMT.LT(SMT.Int(0), self.expr_to_smt(expr))
if not self.solver.is_valid(e_pos):
eg = self.counter_example()
self.err(
expr,
f"expected expression {expr} to always be positive. "
f"It can be non positive when:\n {eg}.",
)

def check_non_negative(self, expr):
e_nn = SMT.LE(SMT.Int(0), self.expr_to_smt(expr))
if not self.solver.is_valid(e_nn):
Expand Down Expand Up @@ -1002,9 +992,9 @@ def bd_pred(x, lo, hi, srcinfo):

elif isinstance(stmt, LoopIR.Alloc):
shape = [lift_expr(s) for s in stmt.type.shape()]
# check that all sizes are positive
# check that all sizes are non-negative
for s in shape:
self.check_pos_size(s)
self.check_non_negative(s)
# check that all accesses are in bounds
self.check_bounds(stmt.name, shape, body_eff)
body_eff = eff_remove_buf(stmt.name, body_eff)
Expand All @@ -1022,9 +1012,9 @@ def bd_pred(x, lo, hi, srcinfo):
pos_sz = SMT.LT(SMT.Int(0), self.sym_to_smt(sig.name))
self.solver.add_assertion(pos_sz)

# check the caller argument always be positive for sizes
# check the caller argument always be non-negative for sizes
e_arg = lift_expr(arg)
self.check_pos_size(e_arg)
self.check_non_negative(e_arg)

# Add type assertion from the caller types
if arg.type.is_tensor_or_window() and not arg.type.is_win():
Expand All @@ -1038,10 +1028,10 @@ def bd_pred(x, lo, hi, srcinfo):
bind[sig.name] = arg.name

# need to check that the argument shape
# has all positive dimensions
# has all non-negative dimensions
arg_shape = [lift_expr(s) for s in arg.type.shape()]
for e in arg_shape:
self.check_pos_size(e)
self.check_non_negative(e)
# also, need to check that the argument shape
# is exactly the shape specified in the signature
sig_shape = [
Expand All @@ -1052,22 +1042,11 @@ def bd_pred(x, lo, hi, srcinfo):
else:
bind[sig.name] = lift_expr(arg)

# map body of the subprocedure
self.preprocess_stmts(stmt.f.body)
eff = self.map_stmts(stmt.f.body, self.rec_proc_types(stmt.f))
eff = eff.subst(bind)

# translate effects occuring on windowed arguments
for sig, arg in zip(stmt.f.args, stmt.args):
if sig.type.is_numeric():
if isinstance(arg.type, T.Window):
eff = self.translate_eff(eff, sig.name, arg.type, type_env)

# Check that asserts are correct
for p in stmt.f.preds:
p_subst = loopir_subst(p, subst)
smt_pred = self.expr_to_smt(lift_expr(p_subst))
if not self.solver.is_valid(smt_pred):
subst_pred = self.expr_to_smt(lift_expr(p_subst))
if not self.solver.is_valid(subst_pred):
eg = self.counter_example()
self.err(
stmt,
Expand All @@ -1076,6 +1055,21 @@ def bd_pred(x, lo, hi, srcinfo):
f" Assertion is false when:\n {eg}",
)

# Add assertion to SMT when/if this assertions is satisfiable
smt_p = self.expr_to_smt(lift_expr(p))
self.solver.add_assertion(smt_p)

# map body of the subprocedure
self.preprocess_stmts(stmt.f.body)
eff = self.map_stmts(stmt.f.body, self.rec_proc_types(stmt.f))
eff = eff.subst(bind)

# translate effects occuring on windowed arguments
for sig, arg in zip(stmt.f.args, stmt.args):
if sig.type.is_numeric():
if isinstance(arg.type, T.Window):
eff = self.translate_eff(eff, sig.name, arg.type, type_env)

self.pop()

body_eff = eff_concat(eff, body_eff)
Expand Down
3 changes: 3 additions & 0 deletions src/exo/rewrite/LoopIR_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..core.prelude import *
from .new_eff import Check_Aliasing
import exo.core.internal_cursors as ic
from ..frontend.boundscheck import CheckBounds


def _get_smt_solver():
Expand Down Expand Up @@ -94,6 +95,8 @@ def DoReplace(subproc, block_cursor):

ir, fwd = block_cursor._replace([new_call])
Check_Aliasing(ir)
print(ir)
CheckBounds(ir)
return ir, fwd


Expand Down
2 changes: 1 addition & 1 deletion tests/golden/test_schedules/test_unify11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ def foo(n: size, m: size, x: f32[n] @ DRAM):
assert -m + n >= 1
assert -m + n <= 8
y: f32[8] @ DRAM
bar(y[0:8], x[0:8], -n + m)
bar2(y[0:8], x[0:8], -n + m)
15 changes: 12 additions & 3 deletions tests/test_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2536,7 +2536,13 @@ def foo(n: size, m: size, x: f32[n]):

def test_unify11(golden):
@proc
def bar(dst: [f32][8], src: [f32][8], bound: size):
def bar1(dst: [f32][8], src: [f32][8], bound: size):
for i in seq(0, 8):
if i < bound:
dst[i] = src[i]

@proc
def bar2(dst: [f32][8], src: [f32][8], bound: index):
for i in seq(0, 8):
if i < bound:
dst[i] = src[i]
Expand All @@ -2550,13 +2556,16 @@ def foo(n: size, m: size, x: f32[n]):
if m > n + i:
y[i] = x[i]

foo = replace(foo, foo.find_loop("i"), bar)
with pytest.raises(TypeError, match="expected expression"):
replace(foo, foo.find_loop("i"), bar1)

foo = replace(foo, foo.find_loop("i"), bar2)
assert str(simplify(foo)) == golden


def test_unify12(golden):
@proc
def bar(dst: [f32][8], src: [f32][8], bound: size):
def bar(dst: [f32][8], src: [f32][8], bound: index):
for i in seq(0, 8):
if i < bound:
dst[i] = src[i]
Expand Down
4 changes: 4 additions & 0 deletions tests/test_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def sgemm_6x16(
A: [f32][6, K] @ DRAM,
B: [f32][K, 16] @ DRAM,
):
assert stride(C, 1) == 1
assert stride(A, 1) == 1
assert stride(B, 1) == 1

for i in seq(0, 6):
for j in seq(0, 16):
for k in seq(0, K):
Expand Down
Loading