diff --git a/src/exo/frontend/boundscheck.py b/src/exo/frontend/boundscheck.py index b82bddc93..8fb72af58 100644 --- a/src/exo/frontend/boundscheck.py +++ b/src/exo/frontend/boundscheck.py @@ -569,9 +569,9 @@ def __init__(self, proc): for arg in proc.args: if arg.type.is_numeric(): shape = [lift_expr(s) for s in arg.type.shape()] - # check that all sizes/indices are positive + # check that all sizes/indices are non-negative for s in shape: - self.check_pos_size(s) + self.check_non_negative(s) # check the bounds self.check_bounds(arg.name, shape, body_eff) @@ -848,16 +848,6 @@ def check_bounds(self, sym, shape, eff): for e in es: self.check_in_bounds(sym, shape, e, y) - def check_pos_size(self, expr): - e_pos = SMT.LT(SMT.Int(0), self.expr_to_smt(expr)) - if not self.solver.is_valid(e_pos): - eg = self.counter_example() - self.err( - expr, - f"expected expression {expr} to always be positive. " - f"It can be non positive when:\n {eg}.", - ) - def check_non_negative(self, expr): e_nn = SMT.LE(SMT.Int(0), self.expr_to_smt(expr)) if not self.solver.is_valid(e_nn): @@ -1002,9 +992,9 @@ def bd_pred(x, lo, hi, srcinfo): elif isinstance(stmt, LoopIR.Alloc): shape = [lift_expr(s) for s in stmt.type.shape()] - # check that all sizes are positive + # check that all sizes are non-negative for s in shape: - self.check_pos_size(s) + self.check_non_negative(s) # check that all accesses are in bounds self.check_bounds(stmt.name, shape, body_eff) body_eff = eff_remove_buf(stmt.name, body_eff) @@ -1022,9 +1012,9 @@ def bd_pred(x, lo, hi, srcinfo): pos_sz = SMT.LT(SMT.Int(0), self.sym_to_smt(sig.name)) self.solver.add_assertion(pos_sz) - # check the caller argument always be positive for sizes + # check that the caller argument is always non-negative for sizes e_arg = lift_expr(arg) - self.check_pos_size(e_arg) + self.check_non_negative(e_arg) # Add type assertion from the caller types if arg.type.is_tensor_or_window() and not arg.type.is_win(): @@ -1038,10 +1028,10 @@ def bd_pred(x, lo, 
hi, srcinfo): bind[sig.name] = arg.name # need to check that the argument shape - # has all positive dimensions + # has all non-negative dimensions arg_shape = [lift_expr(s) for s in arg.type.shape()] for e in arg_shape: - self.check_pos_size(e) + self.check_non_negative(e) # also, need to check that the argument shape # is exactly the shape specified in the signature sig_shape = [ @@ -1052,22 +1042,11 @@ def bd_pred(x, lo, hi, srcinfo): else: bind[sig.name] = lift_expr(arg) - # map body of the subprocedure - self.preprocess_stmts(stmt.f.body) - eff = self.map_stmts(stmt.f.body, self.rec_proc_types(stmt.f)) - eff = eff.subst(bind) - - # translate effects occuring on windowed arguments - for sig, arg in zip(stmt.f.args, stmt.args): - if sig.type.is_numeric(): - if isinstance(arg.type, T.Window): - eff = self.translate_eff(eff, sig.name, arg.type, type_env) - # Check that asserts are correct for p in stmt.f.preds: p_subst = loopir_subst(p, subst) - smt_pred = self.expr_to_smt(lift_expr(p_subst)) - if not self.solver.is_valid(smt_pred): + subst_pred = self.expr_to_smt(lift_expr(p_subst)) + if not self.solver.is_valid(subst_pred): eg = self.counter_example() self.err( stmt, @@ -1076,6 +1055,21 @@ def bd_pred(x, lo, hi, srcinfo): f" Assertion is false when:\n {eg}", ) + # Add assertion to SMT when/if this assertion is satisfiable + smt_p = self.expr_to_smt(lift_expr(p)) + self.solver.add_assertion(smt_p) + + # map body of the subprocedure + self.preprocess_stmts(stmt.f.body) + eff = self.map_stmts(stmt.f.body, self.rec_proc_types(stmt.f)) + eff = eff.subst(bind) + + # translate effects occurring on windowed arguments + for sig, arg in zip(stmt.f.args, stmt.args): + if sig.type.is_numeric(): + if isinstance(arg.type, T.Window): + eff = self.translate_eff(eff, sig.name, arg.type, type_env) + self.pop() body_eff = eff_concat(eff, body_eff) diff --git a/src/exo/rewrite/LoopIR_unification.py b/src/exo/rewrite/LoopIR_unification.py index 16cf2182d..9a2b27c26 100644 --- 
a/src/exo/rewrite/LoopIR_unification.py +++ b/src/exo/rewrite/LoopIR_unification.py @@ -20,6 +20,7 @@ from ..core.prelude import * from .new_eff import Check_Aliasing import exo.core.internal_cursors as ic +from ..frontend.boundscheck import CheckBounds def _get_smt_solver(): @@ -94,6 +95,7 @@ def DoReplace(subproc, block_cursor): ir, fwd = block_cursor._replace([new_call]) Check_Aliasing(ir) + CheckBounds(ir) return ir, fwd diff --git a/tests/golden/test_schedules/test_unify11.txt b/tests/golden/test_schedules/test_unify11.txt index d5914d685..3dc1fa3f5 100644 --- a/tests/golden/test_schedules/test_unify11.txt +++ b/tests/golden/test_schedules/test_unify11.txt @@ -2,4 +2,4 @@ def foo(n: size, m: size, x: f32[n] @ DRAM): assert -m + n >= 1 assert -m + n <= 8 y: f32[8] @ DRAM - bar(y[0:8], x[0:8], -n + m) \ No newline at end of file + bar2(y[0:8], x[0:8], -n + m) \ No newline at end of file diff --git a/tests/test_schedules.py b/tests/test_schedules.py index a7a8bfa56..2a8715066 100644 --- a/tests/test_schedules.py +++ b/tests/test_schedules.py @@ -2536,7 +2536,13 @@ def foo(n: size, m: size, x: f32[n]): def test_unify11(golden): @proc - def bar(dst: [f32][8], src: [f32][8], bound: size): + def bar1(dst: [f32][8], src: [f32][8], bound: size): + for i in seq(0, 8): + if i < bound: + dst[i] = src[i] + + @proc + def bar2(dst: [f32][8], src: [f32][8], bound: index): for i in seq(0, 8): if i < bound: dst[i] = src[i] @@ -2550,13 +2556,16 @@ def foo(n: size, m: size, x: f32[n]): if m > n + i: y[i] = x[i] - foo = replace(foo, foo.find_loop("i"), bar) + with pytest.raises(TypeError, match="expected expression"): + replace(foo, foo.find_loop("i"), bar1) + + foo = replace(foo, foo.find_loop("i"), bar2) assert str(simplify(foo)) == golden def test_unify12(golden): @proc - def bar(dst: [f32][8], src: [f32][8], bound: size): + def bar(dst: [f32][8], src: [f32][8], bound: index): for i in seq(0, 8): if i < bound: dst[i] = src[i] diff --git a/tests/test_x86.py 
b/tests/test_x86.py index 57547c02c..6d1b423dc 100644 --- a/tests/test_x86.py +++ b/tests/test_x86.py @@ -169,6 +169,10 @@ def sgemm_6x16( A: [f32][6, K] @ DRAM, B: [f32][K, 16] @ DRAM, ): + assert stride(C, 1) == 1 + assert stride(A, 1) == 1 + assert stride(B, 1) == 1 + for i in seq(0, 6): for j in seq(0, 16): for k in seq(0, K):