diff --git a/.gitignore b/.gitignore
index ac89d54c..c55af033 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,4 +25,8 @@ configure
 dependencies/chipyard
 
+examples/*.c
+examples/*.d
+examples/*.h
+
 .vscode
diff --git a/examples/arm_matmul.py b/examples/arm_matmul.py
new file mode 100644
index 00000000..a95f5544
--- /dev/null
+++ b/examples/arm_matmul.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+import os
+import sys
+
+from exo import proc
+from exo.platforms.neon import *
+from exo.stdlib.scheduling import *
+
+# Hide output when running through exocc.
+if __name__ != "__main__" and hasattr(os, "devnull"):
+    sys.stdout = open(os.devnull, "w")
+
+# Algorithm definition
+@proc
+def rank_k_reduce_6x16(
+    K: size, A: f32[6, K] @ DRAM, B: f32[K, 16] @ DRAM, C: f32[6, 16] @ DRAM
+):
+    for i in seq(0, 6):
+        for j in seq(0, 16):
+            for k in seq(0, K):
+                C[i, j] += A[i, k] * B[k, j]
+
+
+print("\n============= original ==============")
+print(rank_k_reduce_6x16)
+
+print("\n============= reorder loops ==============")
+neon = rename(rank_k_reduce_6x16, "rank_k_reduce_6x16_scheduled")
+neon = reorder_loops(neon, "j k")
+neon = reorder_loops(neon, "i k")
+print(neon)
+
+
+print("\n============= divide loop ==============")
+# neon only supports vectors of width 4 for f32;
+# x86 supports either 4 or 8 wide
+vec_reg_width = 4
+neon = divide_loop(neon, "for j in _: _", vec_reg_width, ["jo", "ji"], perfect=True)
+print(neon)
+
+print("\n============= stage mem ==============")
+# We want the computation to be "output stationary", meaning we preallocate
+# all of the output registers at the start. Staging C will consume 24 of the
+# 32 NEON vector registers (6 * 16 f32 values in 4-wide vectors).
+neon = stage_mem(neon, "for k in _:_", "C[0:6, 0:16]", "C_reg")
+print(neon)
+neon = simplify(neon)
+
+print("\n============= reshape C_reg ==============")
+# Reshape C_reg so we can map it into vector registers
+neon = divide_dim(neon, "C_reg:_", 1, vec_reg_width)
+print(neon)
+
+print("\n============= divide loop ==============")
+neon = repeat(divide_loop)(
+    neon, "for i1 in _: _", vec_reg_width, ["i2", "i3"], perfect=True
+)
+neon = simplify(neon)
+print(neon)
+
+print("\n============= map C_reg ops ==============")
+# Map C_reg operations to vector instructions
+neon = set_memory(neon, "C_reg:_", Neon)
+# each neon load moves 4 f32 values (the x86 version moves 8 per load)
+neon = replace_all(neon, neon_vld_4xf32)
+neon = replace_all(neon, neon_vst_4xf32)
+neon = simplify(neon)
+print(neon)
+
+
+# Now, the rest of the compute needs to work with the constraint that
+# only 8 vector registers remain available here.
+
+print("\n============= stage B_reg ==============")
+# B is easy, it is just four vector loads
+neon = stage_mem(neon, "for i in _:_", "B[k, 0:16]", "B_reg")
+neon = simplify(neon)
+print(neon)
+
+print("\n============= block 1st B_reg load ==============")
+neon = divide_loop(neon, "for i0 in _: _ #1", vec_reg_width, ["io", "ii"], perfect=True)
+print(neon)
+
+print("\n============= reshape B_reg ==============")
+neon = divide_dim(neon, "B_reg:_", 0, vec_reg_width)
+print(neon)
+
+print("\n============= map B_reg ops ==============")
+neon = set_memory(neon, "B_reg:_", Neon)
+neon = simplify(neon)
+neon = replace_all(neon, neon_vld_4xf32)
+neon = simplify(neon)
+print(neon)
+
+# Now we've used up four more vector registers.
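+# Register budget so far (assuming AArch64's 32 128-bit vector registers):
+#   C_reg: 6 * 16 f32 / 4 lanes = 24 registers
+#   B_reg:     16 f32 / 4 lanes =  4 registers
+# The A broadcast staged next needs 1 more, keeping us within the 32 available.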
+
+# The final part is staging A
+
+print("\n============= stage A_reg ==============")
+neon = bind_expr(neon, "A[i, k]", "A_reg")
+neon = expand_dim(neon, "A_reg", vec_reg_width, "ji")
+neon = lift_alloc(neon, "A_reg", n_lifts=2)
+neon = fission(neon, neon.find("A_reg[ji] = _").after(), n_lifts=2)
+neon = remove_loop(neon, "for jo in _: _")
+neon = set_memory(neon, "A_reg:_", Neon)
+neon = replace_all(neon, neon_broadcast_4xf32)
+neon = simplify(neon)
+print(neon)
+
+
+# DO THE COMPUTE!!!
+print("\n============= map mult add op ==============")
+neon = replace_all(neon, neon_vfmadd_4xf32_4xf32)
+neon = simplify(neon)
+print(neon)
+
+print("\n============= done! ==============")
diff --git a/examples/avx2_matmul/Makefile b/examples/avx2_matmul/Makefile
index 28acb78f..832e5a99 100644
--- a/examples/avx2_matmul/Makefile
+++ b/examples/avx2_matmul/Makefile
@@ -1,13 +1,25 @@
 CFLAGS ?= -march=native
 
-avx2_matmul: avx2_matmul.o main.o
+.PHONY: x86
+x86: avx2_matmul
 
-avx2_matmul.c: x86_matmul.py
+# x86 build
+avx2_matmul: avx2_matmul.o main.o
+avx2_matmul.h avx2_matmul.c: x86_matmul.py
	exocc -o . --stem $(*F) $^
 
-main.c: avx2_matmul.c
+.PHONY: neon
+neon: neon_matmul
+
+# ARM build
+neon_matmul: neon_matmul.o main.o
+neon_matmul.h neon_matmul.c: arm_matmul.py
+	exocc -o . --stem $(*F) $^
 
 .PHONY: clean
 clean:
-	$(RM) avx2_matmul avx2_matmul.* *.o exo_demo
+	$(RM) *.o exo_demo
 	$(RM) -r __pycache__/
+	$(RM) avx2_matmul avx2_matmul.*
+	$(RM) neon_matmul neon_matmul.*
+
diff --git a/examples/avx2_matmul/main.c b/examples/avx2_matmul/main.c
index 1d2606ac..dc39c947 100644
--- a/examples/avx2_matmul/main.c
+++ b/examples/avx2_matmul/main.c
@@ -1,7 +1,12 @@
+#include <stdint.h>
 #include <stdio.h>
 #include <time.h>
 
-#include "avx2_matmul.h"
+// generated by exo
+void rank_k_reduce_6x16(
+    void *ctxt, int_fast32_t K, const float *A, const float *B, float *C);
+void rank_k_reduce_6x16_scheduled(
+    void *ctxt, int_fast32_t K, const float *A, const float *B, float *C);
 
 #define K 2048
 static float A[6 * K];
@@ -31,6 +36,8 @@ int main() {
   clock_t start, end;
   int msec;
 
+  initialize();
+
   // Calling original matmul
   start = clock();
   for (int i = 0; i < 1000; i++)
diff --git a/examples/matmul_interp.py b/examples/matmul_interp.py
new file mode 100644
index 00000000..854f016d
--- /dev/null
+++ b/examples/matmul_interp.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+import os
+import sys
+import numpy as np
+
+from exo import proc
+from exo.platforms.neon import *
+from exo.stdlib.scheduling import *
+
+# Hide output when running through exocc.
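+# (exocc imports this file as a module rather than running it as __main__,
+# so the guard below silences these prints during code generation)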
+if __name__ != "__main__" and hasattr(os, "devnull"): + sys.stdout = open(os.devnull, "w") + + +@proc +def foo(s: f32, arg: f32[1, 1] @ DRAM): + arg[0, 0] = s + + +# Algorithm definition +@proc +def rank_k_reduce_6x16( + M: size, + K: size, + N: size, + A: f32[M, K] @ DRAM, + B: f32[K, N] @ DRAM, + C: f32[M, N] @ DRAM, + test: f32 @ DRAM, +): + s: f32 + buf: f32[1, 1] + s = 4 + + for i in seq(0, M): + for j in seq(0, N): + for k in seq(0, K): + C[i, j] += A[i, k] * B[k, j] + s: f32 + s = 2 + s = s + 1 + test = s + + +@proc +def check_stride(A: [f32][6] @ DRAM, res: f32 @ DRAM): + assert stride(A, 0) == 2 + for i in seq(0, 6): + res += A[i] + + +# M = 2; K = 2; N = 2 +# A = np.zeros(M*K, dtype=float).reshape((M,K)) +# B = np.arange(K*N, dtype=float).reshape((K,N)) +# C = np.zeros(M*N, dtype=float).reshape((M,N)) +res = np.zeros(1) + +A = np.array([1.0] * 12) +# rank_k_reduce_6x16.interpret(M=M, K=K, N=N, A=A, B=B, C=C, test=res) +check_stride.interpret(A=A[::2], res=res) +print(res) diff --git a/rebuild.sh b/rebuild.sh new file mode 100755 index 00000000..784812a5 --- /dev/null +++ b/rebuild.sh @@ -0,0 +1,4 @@ +#! /bin/bash + +python -m build . +pip install --force-reinstall dist/*.whl \ No newline at end of file diff --git a/rebuild_and_test_interp.sh b/rebuild_and_test_interp.sh new file mode 100755 index 00000000..0d888de1 --- /dev/null +++ b/rebuild_and_test_interp.sh @@ -0,0 +1,5 @@ +#! /bin/bash + +python -m build . +pip install --force-reinstall dist/*.whl +python3 examples/matmul_interp.py \ No newline at end of file diff --git a/src/exo/API.py b/src/exo/API.py index 3a690ca3..5955bc5e 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -18,6 +18,7 @@ from .frontend.pattern_match import match_pattern from .core.prelude import * from .rewrite.new_eff import Check_Aliasing +from .backend.LoopIR_interpreter import run_interpreter # Moved to new file from .core.proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc @@ -302,6 +303,9 @@ def c_code_str(self): def compile_c(self, directory: Path, filename: str): compile_procs([self], directory, f"{filename}.c", f"{filename}.h") + def interpret(self, **kwargs): + run_interpreter(self._loopir_proc, kwargs) + # ------------------------------- # # scheduling operations # ------------------------------- # diff --git a/src/exo/backend/LoopIR_interpreter.py b/src/exo/backend/LoopIR_interpreter.py new file mode 100644 index 00000000..f4b1e819 --- /dev/null +++ b/src/exo/backend/LoopIR_interpreter.py @@ -0,0 +1,326 @@ +from collections import ChainMap, defaultdict + +import numpy as np + +from ..core.LoopIR import LoopIR +from ..core.LoopIR import T +from ..core.prelude import * + +from .parallel_analysis import ParallelAnalysis +from .prec_analysis import PrecisionAnalysis +from .win_analysis import WindowAnalysis +from .mem_analysis import MemoryAnalysis + +# --------------------------------------------------------------------------- # +# --------------------------------------------------------------------------- # +# Loop IR Interpreter + + +# method copied from Python ChainMap docs https://docs.python.org/3/library/collections.html#collections.ChainMap +# to delete items from parent maps +@extclass(ChainMap) +def __delitem__(self, key): + for mapping in self.maps: + if key in mapping: + del mapping[key] + return + raise KeyError(key) + + +def _eshape(typ, env): + return tuple(r if is_pos_int(r) else env[r] for r in typ.shape()) + + +def run_interpreter(proc, kwargs): + Interpreter(proc, kwargs) + + +class Interpreter: + def 
__init__(self, proc, kwargs, use_randomization=False): + if not isinstance(proc, LoopIR.proc): + raise TypeError(f"Expected {proc.name} to be of type proc") + + self.env = ChainMap() + self.use_randomization = use_randomization + self.ctxt = defaultdict(dict) + + self.eval_proc(proc, kwargs) + + def _new_scope(self): + self.env = self.env.new_child() + + def _del_scope(self): + self.env = self.env.parents + + def typecheck_input_buffer(self, proc_arg, kwargs): + nm = proc_arg.name + if not proc_arg.type.is_numeric(): + raise TypeError(f"arg {nm} is expected to be numeric") + + basetype = proc_arg.type.basetype() + buf = kwargs[str(proc_arg.name)] + + pre = f"bad argument '{nm}'" + if not isinstance(buf, np.ndarray): + raise TypeError(f"{pre}: expected numpy.ndarray") + + if isinstance(basetype, T.F32): + if buf.dtype != np.float32: + raise TypeError(f"{pre}: received {buf.dtype} values") + + if isinstance(basetype, T.F16): + if buf.dtype != np.float16: + raise TypeError(f"{pre}: received {buf.dtype} values") + + if isinstance(basetype, (T.F64, T.Num)): + if buf.dtype != np.float64: + raise TypeError(f"{pre}: received {buf.dtype} values") + + if isinstance(basetype, T.INT8): + if buf.dtype != np.int8: + raise TypeError(f"{pre}: received {buf.dtype} values") + + if isinstance(basetype, T.INT32): + if buf.dtype != np.int32: + raise TypeError(f"{pre}: received {buf.dtype} values") + + if isinstance(basetype, T.UINT8): + if buf.dtype != np.uint8: + raise TypeError(f"{pre}: received {buf.dtype} values") + + if isinstance(basetype, T.UINT16): + if buf.dtype != np.uint16: + raise TypeError(f"{pre}: received {buf.dtype} values") + + if proc_arg.type.is_real_scalar(): + if tuple(buf.shape) != (1,): + raise TypeError( + f"{pre}: expected buffer of shape (1,), " + f"but got shape {tuple(buf.shape)}" + ) + else: + shape = self.eval_shape(proc_arg.type) + if shape != tuple(buf.shape): + raise TypeError( + f"{pre}: expected buffer of shape {shape}, " + f"but got shape {tuple(buf.shape)}" + ) + + def eval_proc(self, proc, kwargs): + proc = ParallelAnalysis().run(proc) + proc = PrecisionAnalysis().run(proc) # TODO: need this? + proc = WindowAnalysis().apply_proc(proc) + proc = MemoryAnalysis().run(proc) # TODO: need this? 
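+        # NOTE: these appear to mirror the analysis passes the C backend runs
+        # before codegen, so the interpreter executes the same lowered IR.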
+ + for a in proc.args: + if not str(a.name) in kwargs: + raise TypeError(f"expected argument '{a.name}' to be supplied") + + if a.type is T.size: + if not is_pos_int(kwargs[str(a.name)]): + raise TypeError( + f"expected size '{a.name}' to have positive integer value" + ) + self.env[a.name] = kwargs[str(a.name)] + elif a.type is T.index: + if type(kwargs[str(a.name)]) is not int: + raise TypeError( + f"expected index variable '{a.name}' to be an integer" + ) + self.env[a.name] = kwargs[str(a.name)] + elif a.type is T.bool: + if type(kwargs[str(a.name)]) is not bool: + raise TypeError(f"expected bool variable '{a.name}' to be a bool") + self.env[a.name] = kwargs[str(a.name)] + elif a.type is T.stride: + if type(kwargs[str(a.name)]) is not int: + raise TypeError( + f"expected stride variable '{a.name}' to be an integer" + ) + self.env[a.name] = kwargs[str(a.name)] + else: + self.typecheck_input_buffer(a, kwargs) + self.env[a.name] = kwargs[str(a.name)] + + # evaluate preconditions + for pred in proc.preds: + if isinstance(pred, LoopIR.Const): + continue + else: + assert self.eval_e(pred), "precondition not satisfied" + + # eval statements + self.eval_stmts(proc.body) + + def eval_stmts(self, stmts): + for s in stmts: + self.eval_s(s) + + def eval_s(self, s): + if isinstance(s, LoopIR.Pass): + pass + elif isinstance(s, (LoopIR.Assign, LoopIR.Reduce)): + lbuf = self.env[s.name] + if len(s.idx) == 0: + # lbuf = rhs + idx = (0,) + else: + # lbuf[a0,a1,...] = rhs + idx = tuple(self.eval_e(a) for a in s.idx) + rhs = self.eval_e(s.rhs) + if isinstance(s, LoopIR.Assign): + lbuf[idx] = rhs + else: + lbuf[idx] += rhs + + elif isinstance(s, LoopIR.WriteConfig): + nm = s.config.name() + rhs = self.eval_e(s.rhs) + self.ctxt[nm][s.field] = rhs + + elif isinstance(s, LoopIR.WindowStmt): + # nm = rbuf[...] + assert s.name not in self.env, "WindowStmt should be a fresh assignment" + assert isinstance( + s.rhs, LoopIR.WindowExpr + ), "WindowStmt rhs should be WindowExpr" + self.env[s.name] = self.eval_e(s.rhs) + + elif isinstance(s, LoopIR.If): + cond = self.eval_e(s.cond) + if cond: + self._new_scope() + self.eval_stmts(s.body) + self._del_scope() + if s.orelse and not cond: + self._new_scope() + self.eval_stmts(s.orelse) + self._del_scope() + + elif isinstance(s, LoopIR.For): + # future TODO: handle loop_mode + lo = self.eval_e(s.lo) + hi = self.eval_e(s.hi) + assert self.use_randomization is False, "TODO: Implement Rand" + self._new_scope() + for itr in range(lo, hi): + self.env[s.iter] = itr + self.eval_stmts(s.body) + self._del_scope() + + elif isinstance(s, LoopIR.Alloc): + if s.type.is_real_scalar(): + self.env[s.name] = np.empty([1]) + else: + size = self.eval_shape(s.type) + # TODO: Maybe randomize? 
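+                # np.empty leaves the contents uninitialized, mirroring the
+                # uninitialized heap allocations in the generated C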
+                self.env[s.name] = np.empty(size)
+
+        elif isinstance(s, LoopIR.Free):
+            # uses the ChainMap __delitem__ extension defined above
+            del self.env[s.name]
+
+        elif isinstance(s, LoopIR.Call):
+            argvals = [self.eval_e(a, call_arg=True) for a in s.args]
+            argnames = [str(a.name) for a in s.f.args]
+            kwargs = {nm: val for nm, val in zip(argnames, argvals)}
+            self._new_scope()
+            self.eval_proc(s.f, kwargs)
+            self._del_scope()
+
+        else:
+            assert False, "bad statement case"
+
+    def eval_e(self, e, call_arg=False):
+
+        if isinstance(e, LoopIR.Read):
+            buf = self.env[e.name]
+            if call_arg or isinstance(buf, (int, bool)):
+                # read without indices
+                return buf
+            else:
+                idx = (0,) if len(e.idx) == 0 else tuple(self.eval_e(a) for a in e.idx)
+                return buf[idx]
+
+        elif isinstance(e, LoopIR.WindowExpr):
+            buf = self.env[e.name]
+
+            def stringify_w_access(a):
+                if isinstance(a, LoopIR.Interval):
+                    return f"{self.eval_e(a.lo)}:{self.eval_e(a.hi)}"
+                elif isinstance(a, LoopIR.Point):
+                    return f"{self.eval_e(a.pt)}"
+                else:
+                    assert False, "bad w_access case"
+
+            # hack to handle interval indexes: each LoopIR.Interval is turned
+            # into a string representing the slice
+            idx = (
+                ("0",)
+                if len(e.idx) == 0
+                else tuple(stringify_w_access(a) for a in e.idx)
+            )
+            res = eval(f"buf[{','.join(idx)}]")
+            return res
+
+        elif isinstance(e, LoopIR.Const):
+            return e.val
+
+        elif isinstance(e, LoopIR.BinOp):
+            lhs, rhs = self.eval_e(e.lhs), self.eval_e(e.rhs)
+            if e.op == "+":
+                return lhs + rhs
+            elif e.op == "-":
+                return lhs - rhs
+            elif e.op == "*":
+                return lhs * rhs
+            elif e.op == "/":
+                if isinstance(lhs, int) and isinstance(rhs, int):
+                    # the old implementation returned (lhs + rhs - 1) // rhs with
+                    # no check on rhs, which disagrees with C's truncation toward
+                    # zero (e.g. -3 / 2 is -1 in C); truncate toward zero instead
+                    return int(lhs / rhs)
+                else:
+                    return lhs / rhs
+            elif e.op == "%":
+                return lhs % rhs
+            elif e.op == "==":
+                return lhs == rhs
+            elif e.op == "<":
+                return lhs < rhs
+            elif e.op == ">":
+                return lhs > rhs
+            elif e.op == "<=":
+                return lhs <= rhs
+            elif e.op == ">=":
+                return lhs >= rhs
+            elif e.op == "and":
+                return lhs and rhs
+            elif e.op == "or":
+                return lhs or rhs
+
+        elif isinstance(e, LoopIR.USub):
+            return -self.eval_e(e.arg)
+
+        # BuiltIns don't go through the interpreter; they are called like a
+        # proc (via Call). TODO: discuss to make sure.
+        # elif isinstance(e, LoopIR.BuiltIn):
+        #     assert False, "Not implemented"
+        #     args = [self.eval_e(a) for a in e.args]
+        #     return e.f.interpret(args)
+
+        elif isinstance(e, LoopIR.StrideExpr):
+            buf = self.env[e.name]
+            assert e.dim < len(buf.strides), "invalid dim in stride expression"
+            # the grammar guarantees an int (not an expression)
+            return int(buf.strides[e.dim] / buf.dtype.itemsize)
+
+        elif isinstance(e, LoopIR.ReadConfig):
+            nm = e.config.name()
+            return self.ctxt[nm][e.field]
+
+        else:
+            print(e)
+            assert False, "bad expression case"
+
+    def eval_shape(self, typ):
+        return tuple(self.eval_e(s) for s in typ.shape())
diff --git a/src/exo/backend/parallel_analysis.py b/src/exo/backend/parallel_analysis.py
index f82ae1ae..2ff725e5 100644
--- a/src/exo/backend/parallel_analysis.py
+++ b/src/exo/backend/parallel_analysis.py
@@ -13,7 +13,7 @@ def run(self, proc):
         proc = super().apply_proc(proc)
         if self._errors:
             errs = "\n".join(self._errors)
-            raise TypeError(f"Errors occurred during precision checking:\n{errs}")
+            raise TypeError(f"Errors occurred during parallel analysis:\n{errs}")
         return proc
 
     def err(self, node, msg):
diff --git a/tests/test_interp.py b/tests/test_interp.py
new
file mode 100644 index 00000000..4096fc7a --- /dev/null +++ b/tests/test_interp.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +import pytest + +import numpy as np + +from exo import proc, config, instr +from exo.libs.memories import GEMM_SCRATCH +from exo.stdlib.scheduling import SchedulingError + +# ------- Interpreter tests --------- + + +def test_mat_mul(compiler): + @proc + def rank_k_reduce( + K: size, + A: f32[6, K], + B: f32[K, 16], + C: f32[6, 16], + ): + for i in seq(0, 6): + for j in seq(0, 16): + for k in seq(0, K): + C[i, j] += A[i, k] * B[k, j] + + fn = compiler.compile(rank_k_reduce) + + K = 8 + A = np.arange(6 * K, dtype=np.float32).reshape((6, K)) + B = np.arange(K * 16, dtype=np.float32).reshape((K, 16)) + C1 = np.zeros(6 * 16, dtype=np.float32).reshape((6, 16)) + C2 = np.zeros(6 * 16, dtype=np.float32).reshape((6, 16)) + + fn(None, K, A, B, C1) + rank_k_reduce.interpret(K=K, A=A, B=B, C=C2) + assert (C1 == C2).all() + + +def test_reduce_add(compiler): + @proc + def acc(N: size, A: f32[N], acc: f32): + acc = 0 + for i in seq(0, N): + acc += A[i] + + fn = compiler.compile(acc) + + n = 3 + A = np.arange(n, dtype=np.float32) + x = np.zeros(1, dtype=np.float32) + y = np.zeros(1, dtype=np.float32) + + fn(None, n, A, x) + acc.interpret(N=n, A=A, acc=y) + assert x == y + + +def test_scope1(compiler): + @proc + def foo(res: f32): + a: f32 + a = 1 + for i in seq(0, 4): + a: f32 + a = 2 + res = a + + fn = compiler.compile(foo) + + x = np.zeros(1, dtype=np.float32) + y = np.zeros(1, dtype=np.float32) + + fn(None, x) + foo.interpret(res=y) + assert x == y + + +def test_scope2(compiler): + @proc + def foo(res: f32): + a: f32 + a = 1 + for i in seq(0, 4): + a = 2 + res = a + + fn = compiler.compile(foo) + + x = np.zeros(1, dtype=np.float32) + y = np.zeros(1, dtype=np.float32) + + fn(None, x) + foo.interpret(res=y) + assert x == y + + +def test_empty_seq(compiler): + @proc + def foo(res: f32): + for i in seq(0, 0): + res = 1 + + fn = compiler.compile(foo) + + x = np.zeros(1, dtype=np.float32) + y = np.zeros(1, dtype=np.float32) + + fn(None, x) + foo.interpret(res=y) + assert x == y + + +def test_cond(compiler): + @proc + def foo(res: f32, p: bool): + if p: + res = 1 + else: + res = 2 + + fn = compiler.compile(foo) + + x = np.zeros(1, dtype=np.float32) + y = np.zeros(1, dtype=np.float32) + + fn(None, x, False) + foo.interpret(res=y, p=False) + assert x == y + + +def test_call(compiler): + @proc + def bar(res: f32): + res = 3 + + @proc + def foo(res: f32): + res = 2 + bar(res) + res += 1 + + fn = compiler.compile(foo) + + x = np.zeros(1, dtype=np.float32) + y = np.zeros(1, dtype=np.float32) + + fn(None, x) + foo.interpret(res=y) + assert x == y + + +def test_window_assert(compiler): + @proc + def foo( + n: size, + m: size, + src: [f32][n, m], + dst: [f32][n, 16], + ): + assert n <= 16 + assert m <= 16 + + for i in seq(0, n): + for j in seq(0, m): + dst[i, j] = src[i, j] + + n = 6 + m = 8 + src = np.arange(n * m, dtype=np.float32).reshape((n, m)) + dst = np.zeros(n * 16, dtype=np.float32).reshape((n, 16)) + + foo.interpret(n=n, m=m, src=src, dst=dst) + assert (dst[:, :8] == src).all() + + +def test_window_stmt1(compiler): + @proc + def foo(n: size, A: f32[n, 16], C: f32[n]): + B = A[:, 0] + for i in seq(0, n): + C[i] = B[i] + + fn = compiler.compile(foo) + + n = 6 + A = np.arange(n * 16, dtype=np.float32).reshape((n, 16)) + C1 = np.arange(n, dtype=np.float32) + C2 = np.arange(n, dtype=np.float32) + + fn(None, n, A, C1) + foo.interpret(n=n, A=A, C=C2) + + assert (C1 == C2).all() 
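+
+# Note: in the interpreter, a window such as A[:, 0] evaluates to a numpy
+# view, so the window aliases the original buffer instead of copying it.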
+
+
+def test_window_stmt2(compiler):
+    @proc
+    def foo(n: size, A: f32[n], B: f32[n], C: f32[2 * n]):
+        for i in seq(0, n):
+            C[i] = A[i]
+        for i in seq(n, 2 * n):
+            C[i] = B[i - n]
+
+    fn = compiler.compile(foo)
+
+    n = 6
+    A = np.arange(n, dtype=np.float32)
+    B = np.arange(n, dtype=np.float32)
+    C1 = np.zeros(2 * n, dtype=np.float32)
+    C2 = np.zeros(2 * n, dtype=np.float32)
+
+    fn(None, n, A, B, C1)
+    foo.interpret(n=n, A=A, B=B, C=C2)
+    assert (C1 == C2).all()
+
+
+def test_window_stmt3(compiler):
+    @proc
+    def foo(A: f32[8], res: f32):
+        B = A[4:]
+        res = B[0]
+
+    fn = compiler.compile(foo)
+
+    A = np.arange(8, dtype=np.float32)
+    x = np.zeros(1, dtype=np.float32)
+    y = np.zeros(1, dtype=np.float32)
+
+    fn(None, A, x)
+    foo.interpret(A=A, res=y)
+    assert x[0] == 4 and x == y
+
+
+# TODO: discuss; the error message here could be better
+def test_window_stmt4(compiler):
+    @proc
+    def foo(A: f32[8], C: [f32][4]):
+        B = A[4:]
+        C = B[:]
+
+
+def test_stride_simple1(compiler):
+    @proc
+    def bar(s0: stride, s1: stride, B: [i8][3, 4]):
+        assert stride(B, 0) == s0
+        assert stride(B, 1) == s1
+        pass
+
+    @proc
+    def foo(A: i8[3, 4]):
+        bar(stride(A, 0), stride(A, 1), A[:, :])
+
+    fn = compiler.compile(foo)
+
+    A = np.arange(3 * 4, dtype=np.int8).reshape((3, 4))
+
+    fn(None, A)
+    foo.interpret(A=A)
+
+
+def test_stride_simple2(compiler):
+    @proc
+    def bar(s0: stride, s1: stride, B: [i8][1, 1]):
+        assert stride(B, 0) == s0
+        assert stride(B, 1) == s1
+        pass
+
+    @proc
+    def foo(A: [i8][3, 4]):
+        bar(stride(A, 0), stride(A, 1), A[0:1, 1:2])
+
+    fn = compiler.compile(foo)
+
+    A = np.arange(6 * 8, dtype=np.int8).reshape((6, 8))
+
+    fn(None, A[::2, ::2])
+    foo.interpret(A=A[::2, ::2])
+
+
+def test_stride1(compiler):
+    @proc
+    def foo(A: [i8][3, 2, 3]):
+        assert stride(A, 0) == 20
+        assert stride(A, 1) == 5 * 2
+        assert stride(A, 2) == 1 * 2
+        pass
+
+    fn = compiler.compile(foo)
+
+    A = np.arange(3 * 4 * 5, dtype=np.int8).reshape((3, 4, 5))
+
+    fn(None, A[::1, ::2, ::2])
+    foo.interpret(A=A[::1, ::2, ::2])
+
+
+def test_stride2(compiler):
+    @proc
+    def foo(A: [i8][2, 4, 2]):
+        assert stride(A, 0) == 20 * 2
+        assert stride(A, 1) == 5 * 1
+        assert stride(A, 2) == 1 * 3
+        pass
+
+    fn = compiler.compile(foo)
+
+    A = np.arange(3 * 4 * 5, dtype=np.int8).reshape((3, 4, 5))
+
+    fn(None, A[::2, ::1, ::3])
+    foo.interpret(A=A[::2, ::1, ::3])
+
+
+# TODO: discuss
+# updating a param inside a conditional on a stride triggers a validation error
+def test_branch_stride1(compiler):
+    @proc
+    def bar(B: [i8][3, 4], res: f32):
+        if stride(B, 0) == 8:
+            res = 1
+
+    @proc
+    def foo(A: i8[3, 4], res: f32):
+        bar(A[:, :], res)
+
+
+# but this is okay:
+def test_branch_stride2(compiler):
+    @proc
+    def bar(B: [i8][3, 4], res: f32):
+        if stride(B, 0) == 8:
+            res = 1
+
+    @proc
+    def foo(A: i8[3, 4], res: f32):
+        bar(A, res)
+
+
+# so is this
+def test_branch_stride3(compiler):
+    @proc
+    def bar(B: [i8][3, 4], res: f32):
+        a: f32
+        a = 0
+        if stride(B, 0) == 8:
+            a = 1
+        res = a
+
+    @proc
+    def foo(A: i8[3, 4], res: f32):
+        bar(A[:, :], res)
+
+
+def test_bounds_err_interp():
+    with pytest.raises(TypeError):
+
+        @proc
+        def foo(N: size, A: f32[N], res: f32):
+            a: f32
+            res = A[3]
+
+        N = 2
+        A = np.arange(N, dtype=np.float32)
+        x = np.zeros(1, dtype=np.float32)
+
+        foo.interpret(N=N, A=A, res=x)
+
+
+def test_precond_interp_simple():
+    with pytest.raises(AssertionError):
+
+        @proc
+        def foo(N: size, A: f32[N], res: f32):
+            assert N == 4
+            res = A[3]
+
+        N = 2
+        A = np.arange(N, dtype=np.float32)
+        x = np.zeros(1, dtype=np.float32)
+
+        foo.interpret(N=N, A=A, res=x)
+
+
+def test_precond_interp_stride():
+    with pytest.raises(AssertionError):
+
+        @proc
+        def foo(A: f32[1, 8]):
+            assert stride(A, 0) == 8
+            pass
+
+        A = np.arange(16, dtype=np.float32).reshape((1, 16))
+        foo.interpret(A=A[:, ::2])
+
+
+def new_config():
+    @config
+    class Config:
+        a: f32
+        b: f32
+
+    return Config
+
+
+def test_config(compiler):
+    Config = new_config()
+
+    @proc
+    def foo(x: f32):
+        Config.a = 32.0
+        x = Config.a
+
+    fn = compiler.compile(foo)
+
+    x = np.zeros(1, dtype=np.float32)
+    foo.interpret(x=x)
+    assert x == 32.0
+
+
+def test_config_nested(compiler):
+    Config = new_config()
+
+    @proc
+    def bar(x: f32):
+        x = Config.a + Config.b
+
+    @proc
+    def foo(x: f32):
+        Config.a = 32.0
+        Config.b = 16.0
+        bar(x)
+
+    fn = compiler.compile(foo)
+
+    x = np.zeros(1, dtype=np.float32)
+    foo.interpret(x=x)
+    assert x == 48.0
+
+
+def test_par_bad():
+    with pytest.raises(TypeError):
+
+        @proc
+        def foo(x: f32[10], acc: f32):
+            for i in par(0, 10):
+                acc += x[i]
+
+        x = np.arange(10, dtype=np.float32)
+        a = np.zeros(1, dtype=np.float32)
+
+        foo.interpret(x=x, acc=a)
+
+
+def test_par_good():
+    @proc
+    def foo(x: f32[10]):
+        for i in par(0, 10):
+            x[i] = 1
+
+    x = np.zeros(10, dtype=np.float32)
+
+    foo.interpret(x=x)
+    assert (x == np.ones(10, dtype=np.float32)).all()
+
+
+def test_built_in():
+    @instr("")
+    def four_wide_vector_add(m: size, A: [f64][m], B: [f64][m], C: [f64][m]):
+        assert m >= 4
+        for i in seq(0, 4):
+            C[i] = A[i] + B[i]
+
+    @proc
+    def dumb_vector_add(n: size, A: f64[n], B: f64[n], C: f64[n]):
+        assert n >= 5
+        four_wide_vector_add(n - 1, A[1:], B[1:], C[1:])
+
+    @proc
+    def slightly_smarter_vector_add(n: size, A: f64[n], B: f64[n], C: f64[n]):
+        assert (n % 4) == 0
+        assert n >= 8
+        for j in seq(0, n / 4):
+            four_wide_vector_add(
+                4,
+                A[j * 4 : (j * 4) + 4],
+                B[j * 4 : (j * 4) + 4],
+                C[j * 4 : (j * 4) + 4],
+            )
+
+    A = np.array([1] * 5, dtype=np.float64)
+    B = np.array([2] * 5, dtype=np.float64)
+    C = np.zeros(5, dtype=np.float64)
+
+    dumb_vector_add.interpret(n=5, A=A, B=B, C=C)
+    assert (C == np.array([0, 3, 3, 3, 3], dtype=np.float64)).all()
+
+    A = np.array([1] * 8, dtype=np.float64)
+    B = np.array([2] * 8, dtype=np.float64)
+    C = np.zeros(8, dtype=np.float64)
+    slightly_smarter_vector_add.interpret(n=8, A=A, B=B, C=C)
+    assert (C == np.array([3] * 8, dtype=np.float64)).all()
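For reference, a minimal usage sketch of the new Procedure.interpret entry point added in src/exo/API.py. The scale proc below is hypothetical; per typecheck_input_buffer above, scalar f32 arguments are passed as shape-(1,) float32 arrays, and buffers are mutated in place:

import numpy as np
from exo import proc

@proc
def scale(n: size, c: f32, x: f32[n]):
    for i in seq(0, n):
        x[i] = c * x[i]

c = np.array([2.0], dtype=np.float32)  # scalar f32 -> shape (1,) buffer
x = np.arange(4, dtype=np.float32)
scale.interpret(n=4, c=c, x=x)  # evaluates the LoopIR directly, mutating x
assert (x == np.array([0.0, 2.0, 4.0, 6.0], dtype=np.float32)).all()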