diff --git a/src/exo/API_cursors.py b/src/exo/API_cursors.py index 94869d665..0c00d3b41 100644 --- a/src/exo/API_cursors.py +++ b/src/exo/API_cursors.py @@ -581,6 +581,8 @@ def loopir_type_to_exotype(typ: LoopIR.Type) -> API.ExoType: return API.ExoType.F64 elif isinstance(typ, LoopIR.INT8): return API.ExoType.I8 + elif isinstance(typ, LoopIR.INT16): + return API.ExoType.I16 elif isinstance(typ, LoopIR.INT32): return API.ExoType.I32 elif isinstance(typ, LoopIR.Bool): diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index 948532490..103ec3708 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -206,7 +206,6 @@ def __call__(self, bval, all_args): self.err("expected a bool") return bval - class OptionalA(ArgumentProcessor): def __init__(self, arg_proc): if is_subclass_obj(arg_proc, ArgumentProcessor): @@ -316,6 +315,7 @@ class TypeAbbrevA(ArgumentProcessor): "f32": T.f32, "f64": T.f64, "i8": T.int8, + "i16": T.int16, "ui8": T.uint8, "ui16": T.uint16, "i32": T.int32, @@ -1518,8 +1518,8 @@ def stage_window(proc, expr_cursor, win_name, memory=None): return scheduling.DoStageWindow(proc, win_name, memory, e).result() -@sched_op([BlockCursorA, CustomWindowExprA("block_cursor"), NameA, BoolA]) -def stage_mem(proc, block_cursor, win_expr, new_buf_name, accum=False): +@sched_op([BlockCursorA, CustomWindowExprA("block_cursor"), NameA, BoolA, BoolA]) +def stage_mem(proc, block_cursor, win_expr, new_buf_name, accum=False, init_zero=False): """ Stage the window of memory specified by `win_expr` into a new buffer before the indicated code block and move the memory back after the @@ -1568,7 +1568,7 @@ def stage_mem(proc, block_cursor, win_expr, new_buf_name, accum=False): """ buf_name, w_exprs = win_expr ir, fwd = scheduling.DoStageMem( - block_cursor._impl, buf_name, w_exprs, new_buf_name, use_accum_zero=accum + block_cursor._impl, buf_name, w_exprs, new_buf_name, use_accum_zero=accum, init_zero=init_zero ) return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) diff --git a/src/exo/LoopIR.py b/src/exo/LoopIR.py index 88c83af2a..b064e6523 100644 --- a/src/exo/LoopIR.py +++ b/src/exo/LoopIR.py @@ -107,6 +107,7 @@ def __new__(cls, op): | F64() | INT8() | UINT8() + | INT16() | UINT16() | INT32() | Bool() @@ -140,6 +141,7 @@ def __new__(cls, op): "F32", "F64", "INT8", + "INT16", "UINT16", "UINT8", "INT32" "Bool", @@ -203,6 +205,7 @@ def __new__(cls, op): | F64 () | INT8 () | UINT8 () + | INT16 () | UINT16 () | INT32 () | Bool () @@ -228,6 +231,7 @@ def __new__(cls, op): "F32", "F64", "INT8", + "INT16", "UINT8", "UINT16", "INT32", @@ -363,6 +367,7 @@ def __new__(cls, op): @extclass(UAST.F32) @extclass(UAST.F64) @extclass(UAST.INT8) +@extclass(UAST.INT16) @extclass(UAST.UINT8) @extclass(UAST.UINT16) @extclass(UAST.INT32) @@ -404,6 +409,7 @@ class T: F32 = LoopIR.F32 F64 = LoopIR.F64 INT8 = LoopIR.INT8 + INT16 = LoopIR.INT16 UINT8 = LoopIR.UINT8 UINT16 = LoopIR.UINT16 INT32 = LoopIR.INT32 @@ -420,9 +426,11 @@ class T: f16 = F16() f32 = F32() int8 = INT8() + int16 = INT16() uint8 = UINT8() uint16 = UINT16() i8 = INT8() + i16 = INT16() ui8 = UINT8() ui16 = UINT16() int32 = INT32() @@ -447,6 +455,7 @@ class T: @extclass(T.F32) @extclass(T.F64) @extclass(T.INT8) +@extclass(T.INT16) @extclass(T.UINT8) @extclass(T.UINT16) @extclass(T.INT32) @@ -468,6 +477,7 @@ def shape(t): @extclass(T.F32) @extclass(T.F64) @extclass(T.INT8) +@extclass(T.INT16) @extclass(T.UINT8) @extclass(T.UINT16) @extclass(T.INT32) @@ -487,6 +497,8 @@ def ctype(t): return "double" elif isinstance(t, 
T.INT8): return "int8_t" + elif isinstance(t, T.INT16): + return "int16_t" elif isinstance(t, T.UINT8): return "uint8_t" elif isinstance(t, T.UINT16): @@ -505,7 +517,7 @@ def ctype(t): @extclass(LoopIR.type) def is_real_scalar(t): return isinstance( - t, (T.Num, T.F16, T.F32, T.F64, T.INT8, T.UINT8, T.UINT16, T.INT32) + t, (T.Num, T.F16, T.F32, T.F64, T.INT8, T.INT16, T.UINT8, T.UINT16, T.INT32) ) diff --git a/src/exo/LoopIR_compiler.py b/src/exo/LoopIR_compiler.py index ff34b7154..7f687113f 100644 --- a/src/exo/LoopIR_compiler.py +++ b/src/exo/LoopIR_compiler.py @@ -285,7 +285,7 @@ def _window_struct(typename, ctype, n_dims, is_const) -> WindowStruct: f" const int_fast32_t strides[{n_dims}];\n" f"}};" ) - + #sdef = ("") #ADRIAN return WindowStruct(sname, sdef) @@ -297,6 +297,7 @@ def window_struct(base_type, n_dims, is_const) -> WindowStruct: T.f32: "f32", T.f64: "f64", T.i8: "i8", + T.i16: "i16", T.ui8: "ui8", T.ui16: "ui16", T.i32: "i32", @@ -740,6 +741,7 @@ def access_str(self, nm, idx_list) -> str: if not type.is_win(): return f"{buf}[{idx_expr_s}]" else: + #return f"{buf}[{idx_expr_s}]" return f"{buf}.data[{idx_expr_s}]" def shape_strs(self, shape, prec=100) -> str: @@ -962,6 +964,7 @@ def comp_fnarg(self, e, fn, i, *, prec=0): win_struct = self.get_window_type(e.type, is_const) data, strides = self.window_struct_fields(e) return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}" + #return f"&{data}, {strides}" else: return self.comp_e(e, prec) @@ -990,6 +993,7 @@ def comp_e(self, e, prec=0): win_struct = self.get_window_type(e.type) data, strides = self.window_struct_fields(e) return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}" + #return f"&{data}, {strides}" elif isinstance(e, LoopIR.Const): if isinstance(e.val, bool): diff --git a/src/exo/LoopIR_pprint.py b/src/exo/LoopIR_pprint.py index 807a441c2..83621bb9c 100644 --- a/src/exo/LoopIR_pprint.py +++ b/src/exo/LoopIR_pprint.py @@ -292,6 +292,8 @@ def ptype(self, t): return "f64" elif isinstance(t, UAST.INT8): return "i8" + elif isinstance(t, UAST.INT16): + return "i16" elif isinstance(t, UAST.UINT8): return "ui8" elif isinstance(t, UAST.UINT16): @@ -530,6 +532,8 @@ def _print_type(t, env: PrintEnv) -> str: return "f64" elif isinstance(t, T.INT8): return "i8" + elif isinstance(t, T.INT16): + return "i16" elif isinstance(t, T.UINT8): return "ui8" elif isinstance(t, T.UINT16): diff --git a/src/exo/LoopIR_scheduling.py b/src/exo/LoopIR_scheduling.py index 388c576f9..60081b64a 100644 --- a/src/exo/LoopIR_scheduling.py +++ b/src/exo/LoopIR_scheduling.py @@ -3667,7 +3667,7 @@ def do_e(self, e): pass -def DoStageMem(block_cursor, buf_name, w_exprs, new_name, use_accum_zero=False): +def DoStageMem(block_cursor, buf_name, w_exprs, new_name, use_accum_zero=False, init_zero=False): proc = block_cursor.get_root() new_name = Sym(new_name) @@ -3799,8 +3799,10 @@ def guard_wrapper(body): ) else: load_ridx.append(w) - load_rhs = LoopIR.Read(buf_name, load_ridx, basetyp, srcinfo) - + if init_zero == False: + load_rhs = LoopIR.Read(buf_name, load_ridx, basetyp, srcinfo) + else: + load_rhs = LoopIR.Const(0.0, basetyp, srcinfo) load_nest = [ LoopIR.Assign(new_name, basetyp, None, load_widx, load_rhs, None, srcinfo) ] @@ -3822,9 +3824,10 @@ def guard_wrapper(body): if not use_accum_zero: load_nest_c = fwd(block_cursor[0]).prev() - ir, fwd = insert_safety_guards( - ir, fwd, get_inner_stmt(load_nest_c), load_rhs, buf_typ - ) + if init_zero == False: + ir, fwd = insert_safety_guards( + ir, fwd, get_inner_stmt(load_nest_c), load_rhs, buf_typ + 
) if isW: store_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)] store_ridx = [LoopIR.Read(s, [], T.index, srcinfo) for s in store_iter] diff --git a/src/exo/configs.py b/src/exo/configs.py index c139f3338..1bb7b0fd9 100644 --- a/src/exo/configs.py +++ b/src/exo/configs.py @@ -28,6 +28,7 @@ def new_config(name, fields, disable_rw=False): 'f32' : LoopIR.T.f32, 'f64' : LoopIR.T.f64, 'i8' : LoopIR.T.i8, + 'i16' : LoopIR.T.i16, 'i32' : LoopIR.T.i32, } good_args = (isinstance(name, str) and @@ -61,6 +62,8 @@ def ctyp(typ): return "double" elif isinstance(typ, LoopIR.T.INT8): return "int8_t" + elif isinstance(typ, LoopIR.T.INT16): + return "int16_t" elif isinstance(typ, LoopIR.T.INT32): return "int32_t" elif isinstance(typ, LoopIR.T.Bool): @@ -97,6 +100,7 @@ def __init__(self, name, fields, disable_rw): LoopIR.UAST.F32(): LoopIR.T.f32, LoopIR.UAST.F64(): LoopIR.T.f64, LoopIR.UAST.INT8(): LoopIR.T.i8, + LoopIR.UAST.INT16(): LoopIR.T.i16, LoopIR.UAST.INT32(): LoopIR.T.i32, } diff --git a/src/exo/memory.py b/src/exo/memory.py index c5d1a1770..f4bb86cb5 100644 --- a/src/exo/memory.py +++ b/src/exo/memory.py @@ -109,6 +109,7 @@ def window(cls, basetyp, baseptr, indices, strides, srcinfo): if basetyp.is_win(): baseptr = f"{baseptr}.data" + #baseptr = f"{baseptr}" return f"{baseptr}[{offset}]" diff --git a/src/exo/platforms/neon.py b/src/exo/platforms/neon.py index 0b449afb7..65bf93c6a 100644 --- a/src/exo/platforms/neon.py +++ b/src/exo/platforms/neon.py @@ -31,6 +31,9 @@ def alloc(cls, new_name, prim_type, shape, srcinfo): raise MemGenError(f"{srcinfo}: Neon vectors are not scalar values") vec_types = { + "int8_t": (8, "int8x8_t"), + "int16_t": (4, "int16x4_t"), + "int32_t": (4, "int32x4_t"), "float": (4, "float32x4_t"), "double": (2, "float64x2_t"), "_Float16": (8, "float16x8_t"), @@ -255,7 +258,7 @@ def neon_vfmadd_1xf32_4xf32( # float16 -@instr("{dst_data} = vld1q_f16((float16_t *)&{src_data});") +@instr("{dst_data} = vld1q_f16(&{src_data});") def neon_vld_8xf16(dst: [f16][8] @ Neon, src: [f16][8] @ DRAM): assert stride(src, 0) == 1 assert stride(dst, 0) == 1 @@ -264,7 +267,7 @@ def neon_vld_8xf16(dst: [f16][8] @ Neon, src: [f16][8] @ DRAM): dst[i] = src[i] -@instr("vst1q_f16((float16_t *)&{dst_data}, {src_data});") +@instr("vst1q_f16(&{dst_data}, {src_data});") def neon_vst_8xf16(dst: [f16][8] @ DRAM, src: [f16][8] @ Neon): assert stride(src, 0) == 1 assert stride(dst, 0) == 1 @@ -273,7 +276,7 @@ def neon_vst_8xf16(dst: [f16][8] @ DRAM, src: [f16][8] @ Neon): dst[i] = src[i] -@instr("{dst_data} = vld1q_dup_f16((float16_t *)&{src_data});") +@instr("{dst_data} = vld1q_dup_f16((&{src_data});") def neon_broadcast_8xf16(dst: [f16][8] @ Neon, src: [f16][1] @ DRAM): assert stride(dst, 0) == 1 @@ -281,7 +284,7 @@ def neon_broadcast_8xf16(dst: [f16][8] @ Neon, src: [f16][1] @ DRAM): dst[i] = src[0] -@instr("{dst_data} = vmovq_n_f16(0.0f);") +@instr("{dst_data} = vmovq_n_f16(_Float16)0.0f);") def neon_zero_8xf16(dst: [f16][8] @ Neon): assert stride(dst, 0) == 1 @@ -526,3 +529,95 @@ def neon_convert_f32_upper_to_f64(dst: [f64][2] @ Neon, src: [f32][4] @ Neon): for i in seq(0, 2): dst[i] = src[2 + i] + +# --------------------------------------------------------------------------- # +# mixed integer precision +# just the minimum needed +# --------------------------------------------------------------------------- # +@instr("{dst_data} = vmovl_s8({src_data})") +def neon_convert_i8_to_i16(dst: [i16][8] @ Neon, src: [i8][8] @ Neon): + assert stride(dst, 0) == 1 + assert stride(src, 0) == 1 + + for i in seq(0, 
8): + dst[i] = src[i] + +@instr("{dst_data} = vld1_s8(&{src_data});") +def neon_vld_8xi8(dst: [i8][8] @ Neon, src: [i8][8] @ DRAM, e: index): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + assert e >= 0 + assert e < 8 + + for i in seq(0, e): + dst[i] = src[i] + +@instr("vst1q_s32(&{dst_data}, {src_data});") +def neon_vst_4xi32(dst: [i32][4] @ DRAM, src: [i32][4] @ Neon): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + + for i in seq(0, 4): + dst[i] = src[i] + +@instr("{dst_data} = vld1q_s32(&{src_data});") +def neon_vld_4xi32(dst: [i32][4] @ Neon, src: [i32][4] @ DRAM): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + + for i in seq(0, 4): + dst[i] = src[i] + + +@instr("{dst_data} = vmovq_n_s32(0);") +def neon_zero_4xi32(dst: [i32][4] @ Neon): + assert stride(dst, 0) == 1 + + for i in seq(0, 4): + dst[i] = 0.0 + + + +@instr("{dst_data} = vget_low_s16(vmovl_s8({src_data}));") +def neon_get_low_8xi16(dst: [i16][4] @ Neon, src: [i8][8] @ Neon): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + + for i in seq(0, 4): + dst[i] = src[i] + +@instr("{dst_data} = vget_high_s16(vmovl_s8({src_data}));") +def neon_get_high_8xi16(dst: [i16][4] @ Neon, src: [i8][8] @ Neon): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + + for i in seq(0, 4): + dst[i] = src[i+4] + +@instr("{dst_data} = vmlal_lane_s16({dst_data}, {lhs_data}, {rhs_data}, {jtt});") +def neon_vmlal_8xi16_8xi16( + dst: [i32][4] @ Neon, lhs: [i16][4] @ Neon, rhs: [i16][1,4] @ Neon, + jtt:index, + ): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert jtt >= 0 + assert jtt < 4 + + for i in seq(0, 4): + dst[i] += lhs[i] * rhs[0,jtt] + +@instr("{dst_data} = vmlal_lane_s16({dst_data}, {lhs_data}, {rhs_data}, {jtt});") +def neon_vmlal_sing_8xi16_8xi16( + dst: [i32][4] @ Neon, lhs: [i16][4] @ Neon, rhs: [i16][4] @ Neon, + jtt:index, + ): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert jtt >= 0 + assert jtt < 4 + + for i in seq(0, 4): + dst[i] += lhs[i] * rhs[jtt] diff --git a/src/exo/platforms/rvv.py b/src/exo/platforms/rvv.py index 72bb883a7..947d6e580 100644 --- a/src/exo/platforms/rvv.py +++ b/src/exo/platforms/rvv.py @@ -1,7 +1,7 @@ from __future__ import annotations from exo import Memory, DRAM, instr - +import os def _is_const_size(sz, c): return sz.isdecimal() and int(sz) == c @@ -15,7 +15,6 @@ def _is_some_const_size(sz): # Neon registers # --------------------------------------------------------------------------- # - class RVV(Memory): @classmethod def global_(cls): @@ -29,13 +28,18 @@ def can_read(cls): def alloc(cls, new_name, prim_type, shape, srcinfo): if not shape: raise MemGenError(f"{srcinfo}: RVV vectors are not scalar values") - + factor = 1 + try: + if int(os.environ['RVV_BITS']) > 0: + factor = int(os.environ['RVV_BITS'])/128 + except: + factor = 1 + vec_types = { - "float": (4, "vfloat32m1_t") - } # , "double": (2, "float64x2_t"), "_Float16" : (8, "float16x8_t")} - + "float": (4*factor, "vfloat32m1_t"), "double": (2*factor, "vfloat64m1_t"), "_Float16" : (8*factor, "vfloat16m1_t")} + if not prim_type in vec_types.keys(): - raise MemGenError(f"{srcinfo}: RVV vectors must be f32 (for now)") + raise MemGenError(f"{srcinfo}: RVV vectors must be floats (for now)") reg_width, C_reg_type_name = vec_types[prim_type] @@ -92,6 +96,18 @@ def rvv_vld_4xf32(dst: [f32][4] @ RVV, src: [f32][4] @ DRAM, vl: size): dst[i] = src[i] +@instr("{dst_data} = __riscv_vle32_v_f32m1(&{src_data},{vl});") +def 
rvv_vld_8xf32(dst: [f32][8] @ RVV, src: [f32][8] @ DRAM, vl: size): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src[i] + + + @instr("__riscv_vse32_v_f32m1(&{dst_data}, {src_data},{vl});") def rvv_vst_4xf32(dst: [f32][4] @ DRAM, src: [f32][4] @ RVV, vl: size): assert stride(src, 0) == 1 @@ -102,6 +118,16 @@ def rvv_vst_4xf32(dst: [f32][4] @ DRAM, src: [f32][4] @ RVV, vl: size): for i in seq(0, vl): dst[i] = src[i] +@instr("__riscv_vse32_v_f32m1(&{dst_data}, {src_data},{vl});") +def rvv_vst_8xf32(dst: [f32][8] @ DRAM, src: [f32][8] @ RVV, vl: size): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src[i] + @instr("{dst_data} = __riscv_vfmv_v_f_f32m1({src_data},{vl});") def rvv_broadcast_4xf32(dst: [f32][4] @ RVV, src: [f32][1] @ DRAM, vl: size): @@ -112,6 +138,14 @@ def rvv_broadcast_4xf32(dst: [f32][4] @ RVV, src: [f32][1] @ DRAM, vl: size): for i in seq(0, vl): dst[i] = src[0] +@instr("{dst_data} = __riscv_vfmv_v_f_f32m1({src_data},{vl});") +def rvv_broadcast_8xf32(dst: [f32][8] @ RVV, src: [f32][1] @ DRAM, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src[0] @instr("{dst_data} = __riscv_vfmv_v_f_f32m1({src_data},{vl});") def rvv_broadcast_4xf32_scalar(dst: [f32][4] @ RVV, src: f32 @ DRAM, vl: size): @@ -122,6 +156,14 @@ def rvv_broadcast_4xf32_scalar(dst: [f32][4] @ RVV, src: f32 @ DRAM, vl: size): for i in seq(0, vl): dst[i] = src +@instr("{dst_data} = __riscv_vfmv_v_f_f32m1({src_data},{vl});") +def rvv_broadcast_8xf32_scalar(dst: [f32][8] @ RVV, src: f32 @ DRAM, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src @instr("{dst_data} = __riscv_vfmv_v_f_f32m1(0.0f,{vl});") def rvv_broadcast_4xf32_0(dst: [f32][4] @ RVV, vl: size): @@ -132,6 +174,14 @@ def rvv_broadcast_4xf32_0(dst: [f32][4] @ RVV, vl: size): for i in seq(0, vl): dst[i] = 0.0 +@instr("{dst_data} = __riscv_vfmv_v_f_f32m1(0.0f,{vl});") +def rvv_broadcast_8xf32_0(dst: [f32][8] @ RVV, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = 0.0 @instr("{dst_data} = __riscv_vfmacc_vv_f32m1({dst_data}, {lhs_data}, {rhs_data},{vl});") def rvv_vfmacc_4xf32_4xf32( @@ -146,6 +196,18 @@ def rvv_vfmacc_4xf32_4xf32( for i in seq(0, vl): dst[i] += lhs[i] * rhs[i] +@instr("{dst_data} = __riscv_vfmacc_vv_f32m1({dst_data}, {lhs_data}, {rhs_data},{vl});") +def rvv_vfmacc_8xf32_8xf32( + dst: [f32][8] @ RVV, lhs: [f32][8] @ RVV, rhs: [f32][8] @ RVV, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] += lhs[i] * rhs[i] @instr("{dst_data} = __riscv_vfmacc_vf_f32m1{dst_data}, {rhs_data}, {lhs_data},{vl});") def rvv_vfmacc_4xf32_1xf32( @@ -160,6 +222,18 @@ def rvv_vfmacc_4xf32_1xf32( for i in seq(0, vl): dst[i] += lhs[i] * rhs[0] +@instr("{dst_data} = __riscv_vfmacc_vf_f32m1{dst_data}, {rhs_data}, {lhs_data},{vl});") +def rvv_vfmacc_8xf32_1xf32( + dst: [f32][8] @ RVV, lhs: [f32][8] @ RVV, rhs: [f32][1] @ DRAM, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] += lhs[i] * rhs[0] @instr("{dst_data} = __riscv_vfmacc_vf_f32m1{dst_data}, {lhs_data}, {rhs_data},{vl});") def 
rvv_vfmacc_1xf32_4xf32( @@ -173,3 +247,261 @@ def rvv_vfmacc_1xf32_4xf32( for i in seq(0, vl): dst[i] += lhs[0] * rhs[i] + +@instr("{dst_data} = __riscv_vfmacc_vf_f32m1{dst_data}, {lhs_data}, {rhs_data},{vl});") +def rvv_vfmacc_1xf32_8xf32( + dst: [f32][8] @ RVV, lhs: [f32][1] @ DRAM, rhs: [f32][8] @ RVV, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] += lhs[0] * rhs[i] + + + +@instr("{dst_data} = __riscv_vrgather_vx_f32m1({src_data}, {imm}, {vl});") +def rvv_gather_4xf32(dst: [f32][4] @ RVV, src: [f32][4] @ RVV, imm: index, vl: size): + assert stride(dst, 0) == 1 + assert stride(src, 0) == 1 + assert imm >= 0 + assert imm < 4 + assert vl >= 0 + assert vl <= 4 + + for i in seq(0, vl): + dst[i] = src[imm] + + +@instr("{dst_data} = __riscv_vrgather_vx_f32m1({src_data}, {imm}, {vl});") +def rvv_gather_8xf32(dst: [f32][8] @ RVV, src: [f32][8] @ RVV, imm: index, vl: size): + assert stride(dst, 0) == 1 + assert stride(src, 0) == 1 + assert imm >= 0 + assert imm < 8 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src[imm] + + + +# --------------------------------------------------------------------------- # +# f16 RVV intrinsics +# --------------------------------------------------------------------------- # + +# +# Load, Store, Broadcast, FMAdd, Mul, Add? +# +# float16 + + +@instr("{dst_data} = __riscv_vle16_v_f16m1(&{src_data},{vl});") +def rvv_vld_8xf16(dst: [f16][8] @ RVV, src: [f16][8] @ DRAM, vl: size): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src[i] + +@instr("{dst_data} = __riscv_vle16_v_f16m1(&{src_data},{vl});") +def rvv_vld_16xf16(dst: [f16][16] @ RVV, src: [f16][16] @ DRAM, vl: size): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] = src[i] + + +@instr("__riscv_vse16_v_f16m1(&{dst_data}, {src_data},{vl});") +def rvv_vst_8xf16(dst: [f16][8] @ DRAM, src: [f16][8] @ RVV, vl: size): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src[i] + + +@instr("__riscv_vse16_v_f16m1(&{dst_data}, {src_data},{vl});") +def rvv_vst_16xf16(dst: [f16][16] @ DRAM, src: [f16][16] @ RVV, vl: size): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] = src[i] + +@instr("{dst_data} = __riscv_vfmv_v_f_f16m1({src_data},{vl});") +def rvv_broadcast_8xf16(dst: [f16][8] @ RVV, src: [f16][1] @ DRAM, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src[0] + +@instr("{dst_data} = __riscv_vfmv_v_f_f16m1({src_data},{vl});") +def rvv_broadcast_16xf16(dst: [f16][16] @ RVV, src: [f16][1] @ DRAM, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] = src[0] + + +@instr("{dst_data} = __riscv_vfmv_v_f_f16m1({src_data},{vl});") +def rvv_broadcast_8xf16_scalar(dst: [f16][8] @ RVV, src: f16 @ DRAM, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src + +@instr("{dst_data} = __riscv_vfmv_v_f_f16m1({src_data},{vl});") +def rvv_broadcast_16xf16_scalar(dst: [f16][16] @ RVV, src: f16 @ DRAM, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] = src + 
+ +@instr("{dst_data} = __riscv_vfmv_v_f_f16m1(0.0f,{vl});") +def rvv_broadcast_8xf16_0(dst: [f16][8] @ RVV, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = 0.0 + +@instr("{dst_data} = __riscv_vfmv_v_f_f16m1(0.0f,{vl});") +def rvv_broadcast_16xf16_0(dst: [f16][16] @ RVV, vl: size): + assert stride(dst, 0) == 1 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] = 0.0 + + +@instr("{dst_data} = __riscv_vfmacc_vv_f16m1({dst_data}, {lhs_data}, {rhs_data},{vl});") +def rvv_vfmacc_8xf16_8xf16( + dst: [f16][8] @ RVV, lhs: [f16][8] @ RVV, rhs: [f16][8] @ RVV, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] += lhs[i] * rhs[i] + + +@instr("{dst_data} = __riscv_vfmacc_vv_f16m1({dst_data}, {lhs_data}, {rhs_data},{vl});") +def rvv_vfmacc_16xf16_16xf16( + dst: [f16][16] @ RVV, lhs: [f16][16] @ RVV, rhs: [f16][16] @ RVV, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] += lhs[i] * rhs[i] + + +@instr("{dst_data} = __riscv_vfmacc_vf_f16m1{dst_data}, {rhs_data}, {lhs_data},{vl});") +def rvv_vfmacc_8xf16_1xf16( + dst: [f16][8] @ RVV, lhs: [f16][8] @ RVV, rhs: [f16][1] @ DRAM, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] += lhs[i] * rhs[0] + + +@instr("{dst_data} = __riscv_vfmacc_vf_f16m1{dst_data}, {rhs_data}, {lhs_data},{vl});") +def rvv_vfmacc_16xf16_1xf16( + dst: [f16][16] @ RVV, lhs: [f16][16] @ RVV, rhs: [f16][1] @ DRAM, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] += lhs[i] * rhs[0] + +@instr("{dst_data} = __riscv_vfmacc_vf_f16m1{dst_data}, {lhs_data}, {rhs_data},{vl});") +def rvv_vfmacc_1xf16_8xf16( + dst: [f16][8] @ RVV, lhs: [f16][1] @ DRAM, rhs: [f16][8] @ RVV, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] += lhs[0] * rhs[i] + +@instr("{dst_data} = __riscv_vfmacc_vf_f16m1{dst_data}, {lhs_data}, {rhs_data},{vl});") +def rvv_vfmacc_1xf16_16xf16( + dst: [f16][16] @ RVV, lhs: [f16][1] @ DRAM, rhs: [f16][16] @ RVV, vl: size +): + assert stride(dst, 0) == 1 + assert stride(lhs, 0) == 1 + assert stride(rhs, 0) == 1 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] += lhs[0] * rhs[i] + + + +@instr("{dst_data} = __riscv_vrgather_vx_f16m1({src_data}, {imm}, {vl});") +def rvv_gather_8xf16(dst: [f16][8] @ RVV, src: [f16][8] @ RVV, imm: index, vl: size): + assert stride(dst, 0) == 1 + assert stride(src, 0) == 1 + assert imm >= 0 + assert imm < 8 + assert vl >= 0 + assert vl <= 8 + + for i in seq(0, vl): + dst[i] = src[imm] + +@instr("{dst_data} = __riscv_vrgather_vx_f16m1({src_data}, {imm}, {vl});") +def rvv_gather_16xf16(dst: [f16][16] @ RVV, src: [f16][16] @ RVV, imm: index, vl: size): + assert stride(dst, 0) == 1 + assert stride(src, 0) == 1 + assert imm >= 0 + assert imm < 16 + assert vl >= 0 + assert vl <= 16 + + for i in seq(0, vl): + dst[i] = src[imm] diff --git a/src/exo/prec_analysis.py b/src/exo/prec_analysis.py index a73dc78d5..49f313882 100644 --- a/src/exo/prec_analysis.py +++ b/src/exo/prec_analysis.py 
@@ -14,6 +14,7 @@ def set_default_prec(name): "f32": T.f32, "f64": T.f64, "i8": T.i8, + "i16": T.i16, "i32": T.i32, } if name not in vals: diff --git a/src/exo/pyparser.py b/src/exo/pyparser.py index a8ab071fe..a9aef4834 100644 --- a/src/exo/pyparser.py +++ b/src/exo/pyparser.py @@ -409,6 +409,7 @@ def parse_alloc_typmem(self, node): "f64": UAST.F64(), "i8": UAST.INT8(), "ui8": UAST.UINT8(), + "i16": UAST.INT16(), "ui16": UAST.UINT16(), "i32": UAST.INT32(), } diff --git a/src/exo/query_asts.py b/src/exo/query_asts.py index 1b69e2717..269be96f7 100644 --- a/src/exo/query_asts.py +++ b/src/exo/query_asts.py @@ -42,6 +42,7 @@ f32() f64() i8() + i16() i32() bool() int() @@ -117,6 +118,9 @@ class f64(Type): class i8(Type): pass +@_dataclass +class i16(Type): + pass @_dataclass class i32(Type): diff --git a/src/exo/reflection.py b/src/exo/reflection.py index 7eda4da8d..58ba47180 100644 --- a/src/exo/reflection.py +++ b/src/exo/reflection.py @@ -238,6 +238,8 @@ def map_type(self, typ): return QAST.f64() elif typ == T.i8: return QAST.i8() + elif typ == T.i16: + return QAST.i16() elif typ == T.i32: return QAST.i32() elif typ == T.bool: diff --git a/src/exo/typecheck.py b/src/exo/typecheck.py index 5ca58df91..0f25271b0 100644 --- a/src/exo/typecheck.py +++ b/src/exo/typecheck.py @@ -498,6 +498,8 @@ def check_e(self, e): typ = T.uint8 elif lhs.type == T.uint16: typ = T.uint16 + elif lhs.type == T.int16: + typ = T.int16 elif lhs.type == T.int32: typ = T.int32 elif rhs.type.is_real_scalar(): @@ -612,6 +614,7 @@ def check_e(self, e): UAST.F64: T.f64, UAST.INT8: T.int8, UAST.UINT8: T.uint8, + UAST.INT16: T.int16, UAST.UINT16: T.uint16, UAST.INT32: T.int32, UAST.Bool: T.bool,
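With the i16 plumbing threaded through LoopIR.py, API_scheduling.py, configs.py, pyparser.py, query_asts.py, reflection.py, and typecheck.py, i16 becomes a first-class precision alongside i8/i32. A minimal sketch (not part of this diff) of the kind of user-level proc this enables, mirroring the widening pattern used by neon_convert_i8_to_i16 above; the proc name and shapes are illustrative:

from __future__ import annotations
from exo import proc

@proc
def widening_mac(n: size, x: i8[n], w: i16[n], acc: i32[n]):
    # i8 values widen on assignment to an i16 temporary, and the i16 * i16
    # product accumulates into i32 (the new typecheck.py branch above).
    for i in seq(0, n):
        xw: i16
        xw = x[i]
        acc[i] += xw * w[i]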
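stage_mem grows an init_zero flag: when it is set, DoStageMem initializes the staging buffer with a constant 0 instead of loading from the source window, and the safety guards around the (now absent) load are skipped. A hedged usage sketch, assuming a proc p with an output buffer C and a reduction loop over k; the pattern string and buffer names are illustrative, not taken from this diff:

from exo.stdlib.scheduling import stage_mem

# Stage a 1x4 tile of C into a buffer that starts at zero; the write-back
# after the staged block is still a plain assignment, so this presumably
# targets blocks that fully (re)define the staged window.
p = stage_mem(p, "for k in _:_", "C[i, 0:4]", "C_reg", init_zero=True)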
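The new Neon instructions cover a minimal i8 -> i16 -> i32 widening multiply-accumulate path: load eight i8 lanes, widen to the low/high i16x4 halves, and accumulate into i32x4 with vmlal_lane_s16. A hedged sketch of how they could be dropped into a schedule once the loops have been tiled to match the instruction bodies; the pattern strings and loop numbering are illustrative, not taken from this diff:

from exo.stdlib.scheduling import replace
from exo.platforms.neon import (
    neon_zero_4xi32,
    neon_get_low_8xi16,
    neon_get_high_8xi16,
    neon_vmlal_sing_8xi16_8xi16,
    neon_vst_4xi32,
)

# Each replace swaps a loop whose body matches one of the instruction
# definitions above for the corresponding intrinsic call.
p = replace(p, "for i in _:_ #0", neon_zero_4xi32)
p = replace(p, "for i in _:_ #1", neon_get_low_8xi16)
p = replace(p, "for i in _:_ #2", neon_vmlal_sing_8xi16_8xi16)
p = replace(p, "for i in _:_ #3", neon_vst_4xi32)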
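On the RVV side, the allocator now scales the assumed lanes-per-register by the RVV_BITS environment variable relative to a 128-bit baseline (factor = RVV_BITS / 128), which is what the new 8-lane f32 and 8/16-lane f16 intrinsics rely on. A sketch of opting into 256-bit vectors; the variable only needs to be set before compilation, since RVV.alloc reads it when each allocation is emitted:

import os

# With RVV_BITS=256 the factor is 256/128 = 2, so a vfloat32m1_t is treated
# as 8 f32 lanes (and vfloat16m1_t as 16 f16 lanes), matching the rvv_*_8xf32
# and rvv_*_16xf16 instructions above. Unset or non-positive values fall back
# to the 128-bit default.
os.environ["RVV_BITS"] = "256"

Only multiples of 128 keep the computed lane counts integral, so 128, 256, 512, and so on are the sensible settings here.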