diff --git a/src/exo/API.py b/src/exo/API.py index d00ff5de3..8aa4e6fdd 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -13,7 +13,7 @@ from .backend.LoopIR_compiler import run_compile, compile_to_strings from .core.configs import Config from .frontend.boundscheck import CheckBounds -from .core.memory import Memory +from .core.memory import MemWin, Memory, SpecialWindow from .frontend.parse_fragment import parse_fragment from .frontend.pattern_match import match_pattern from .core.prelude import * diff --git a/src/exo/API_cursors.py b/src/exo/API_cursors.py index e9b090544..32f1e8d3a 100644 --- a/src/exo/API_cursors.py +++ b/src/exo/API_cursors.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Any +from typing import List, Any, Optional, Type from . import API # TODO: remove this circular import from .API_types import ExoType, loopir_type_to_exotype from .core.LoopIR import LoopIR from .core.configs import Config -from .core.memory import Memory +from .core.memory import MemWin, Memory, SpecialWindow from .core import internal_cursors as C from .frontend.pattern_match import match_pattern @@ -65,7 +65,7 @@ class Cursor(ABC): | For( name : str, hi : Expr, body : Block ) | Alloc( name : str, mem : Memory? ) | Call( subproc : Procedure, args : ExprList ) - | WindowStmt( name : str, winexpr : WindowExpr ) + | WindowStmt( name : str, winexpr : WindowExpr, special_window : SpecialWindow? 
) Expr ::= Read( name : str, idx : ExprList ) | ReadConfig( config : Config, field : str ) @@ -203,13 +203,13 @@ def name(self) -> str: return self._impl._node.name.name() - def mem(self) -> Memory: + def mem(self) -> MemWin: assert isinstance(self._impl, C.Node) assert isinstance(self._impl._node, LoopIR.fnarg) assert not self._impl._node.type.is_indexable() mem = self._impl._node.mem - assert issubclass(mem, Memory) + assert issubclass(mem, MemWin) return mem def is_tensor(self) -> bool: @@ -657,7 +657,7 @@ class WindowStmtCursor(StmtCursor): """ Cursor pointing to a window declaration statement: ``` - name = winexpr + name = winexpr @ special_window ``` """ @@ -673,6 +673,13 @@ def winexpr(self) -> ExprCursor: return WindowExprCursor(self._impl._child_node("rhs"), self._proc) + def special_window(self) -> Optional[Type[SpecialWindow]]: + assert isinstance(self._impl, C.Node) + assert isinstance(self._impl._node, LoopIR.WindowStmt) + special_window = self._impl._node.special_window + assert special_window is None or issubclass(special_window, SpecialWindow) + return special_window + # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index c35bc280a..f61437730 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -1161,6 +1161,7 @@ def set_window(proc, cursor, is_window=True): return Procedure(ir, _provenance_eq_Procedure=proc, _forward=fwd) +# TODO support SpecialWindow for arg cursor (but not alloc) @sched_op([ArgOrAllocCursorA, MemoryA]) def set_memory(proc, cursor, memory_type): """ diff --git a/src/exo/__init__.py b/src/exo/__init__.py index 95fe0c050..3c63d7d62 100644 --- a/src/exo/__init__.py +++ b/src/exo/__init__.py @@ -10,7 +10,15 @@ from .rewrite.LoopIR_scheduling import SchedulingError from .frontend.parse_fragment import ParseFragmentError from .core.configs import Config -from 
.core.memory import Memory, DRAM +from .core.memory import ( + MemWin, + Memory, + SpecialWindow, + DRAM, + WindowStructCtx, + SpecialWindowFromMemoryCtx, + memwin_template, +) from .core.extern import Extern from . import stdlib @@ -25,9 +33,14 @@ "instr", "config", "Config", + "MemWin", "Memory", - "Extern", + "SpecialWindow", + "WindowStructCtx", + "SpecialWindowFromMemoryCtx", + "memwin_template", "DRAM", + "Extern", "SchedulingError", "ParseFragmentError", # diff --git a/src/exo/backend/LoopIR_compiler.py b/src/exo/backend/LoopIR_compiler.py index 090a215de..81588cc01 100644 --- a/src/exo/backend/LoopIR_compiler.py +++ b/src/exo/backend/LoopIR_compiler.py @@ -5,11 +5,21 @@ from collections import defaultdict from dataclasses import dataclass from pathlib import Path +from typing import List from ..core.LoopIR import LoopIR, LoopIR_Do, get_writes_of_stmts, T, CIR from ..core.configs import ConfigError from .mem_analysis import MemoryAnalysis -from ..core.memory import MemGenError, Memory, DRAM, StaticMemory +from ..core.memory import ( + MemGenError, + MemWin, + Memory, + SpecialWindow, + DRAM, + StaticMemory, + WindowStructCtx, + SpecialWindowFromMemoryCtx, +) from .parallel_analysis import ParallelAnalysis from .prec_analysis import PrecisionAnalysis from ..core.prelude import * @@ -21,6 +31,16 @@ def sanitize_str(s): return re.sub(r"\W", "_", s) +T_shorthand = { + T.f16: "f16", + T.f32: "f32", + T.f64: "f64", + T.i8: "i8", + T.ui8: "ui8", + T.ui16: "ui16", + T.i32: "i32", +} + # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # @@ -170,16 +190,16 @@ def walk(proc, visited): return list(reversed(all_procs)) -class LoopIR_FindMems(LoopIR_Do): +class LoopIR_FindMemWins(LoopIR_Do): def __init__(self, proc): - self._mems = set() + self._memwins = set() for a in proc.args: if a.mem: - self._mems.add(a.mem) + self._memwins.add(a.mem) super().__init__(proc) 
def result(self): - return self._mems + return self._memwins # to improve efficiency def do_e(self, e): @@ -188,7 +208,10 @@ def do_e(self, e): def do_s(self, s): if isinstance(s, LoopIR.Alloc): if s.mem: - self._mems.add(s.mem) + self._memwins.add(s.mem) + elif isinstance(s, LoopIR.WindowStmt): + if s.special_window: + self._memwins.add(s.special_window) else: super().do_s(s) @@ -239,12 +262,11 @@ def do_t(self, t): pass -def find_all_mems(proc_list): - mems = set() +def find_all_memwins(proc_list): + memwins = set() for p in proc_list: - mems.update(LoopIR_FindMems(p).result()) - - return [m for m in mems] + memwins.update(LoopIR_FindMemWins(p).result()) + return memwins def find_all_externs(proc_list): @@ -267,50 +289,14 @@ def find_all_configs(proc_list): # --------------------------------------------------------------------------- # -@dataclass(frozen=True) +@dataclass class WindowStruct: name: str definition: str - - -@functools.cache -def _window_struct(typename, ctype, n_dims, is_const) -> WindowStruct: - const_kwd = "const " if is_const else "" - const_suffix = "c" if is_const else "" - - sname = f"exo_win_{n_dims}{typename}{const_suffix}" - sdef = ( - f"struct {sname}{{\n" - f" {const_kwd}{ctype} * const data;\n" - f" const int_fast32_t strides[{n_dims}];\n" - f"}};" - ) - - sdef_guard = sname.upper() - sdef = f"""#ifndef {sdef_guard} -#define {sdef_guard} -{sdef} -#endif""" - - return WindowStruct(sname, sdef) - - -def window_struct(base_type, n_dims, is_const) -> WindowStruct: - assert n_dims >= 1 - - _window_struct_shorthand = { - T.f16: "f16", - T.f32: "f32", - T.f64: "f64", - T.i8: "i8", - T.ui8: "ui8", - T.ui16: "ui16", - T.i32: "i32", - } - - return _window_struct( - _window_struct_shorthand[base_type], base_type.ctype(), n_dims, is_const - ) + dataptr: str + separate_dataptr: bool + is_const: bool + emit_definition: bool # --------------------------------------------------------------------------- # @@ -371,11 +357,10 @@ def from_lines(x): # 
Header contents ctxt_name, ctxt_def = _compile_context_struct(find_all_configs(proc_list), lib_name) - struct_defns = set() + window_struct_cache = WindowStructCache() public_fwd_decls = [] # Body contents - memory_code = _compile_memories(find_all_mems(proc_list)) private_fwd_decls = [] proc_bodies = [] instrs_global = [] @@ -412,9 +397,10 @@ def from_lines(x): p = WindowAnalysis().apply_proc(p) p = MemoryAnalysis().run(p) - comp = Compiler(p, ctxt_name, is_public_decl=is_public_decl) + comp = Compiler( + p, ctxt_name, window_struct_cache, is_public_decl=is_public_decl + ) d, b = comp.comp_top() - struct_defns |= comp.struct_defns() needed_helpers |= comp.needed_helpers() if is_public_decl: @@ -426,8 +412,13 @@ def from_lines(x): analyzed_proc_list.append(p) - # Structs are just blobs of code... still sort them for output stability - struct_defns = [x.definition for x in sorted(struct_defns, key=lambda x: x.name)] + # Memories and structs are just blobs of code... + # still sort them for output stability + header_memwins, header_memwin_code, body_memwin_code = _compile_memwins(proc_list) + ( + header_struct_defns, + body_struct_defns, + ) = window_struct_cache.sorted_header_body_definitions(header_memwins) header_contents = f""" #include @@ -452,7 +443,8 @@ def from_lines(x): #endif {from_lines(ctxt_def)} -{from_lines(struct_defns)} +{from_lines(header_memwin_code)} +{from_lines(header_struct_defns)} {from_lines(public_fwd_decls)} """ @@ -462,7 +454,8 @@ def from_lines(x): body_contents = [ helper_code, instrs_global, - memory_code, + body_memwin_code, + body_struct_defns, extern_code, private_fwd_decls, proc_bodies, @@ -482,11 +475,25 @@ def _compile_externs(externs): return extern_code -def _compile_memories(mems): - memory_code = [] - for m in sorted(mems, key=lambda x: x.name()): - memory_code.append(m.global_()) - return memory_code +def _compile_memwins(proc_list): + """Return (header memwin set, header memwin code, C body memwin code)""" + all_memwins = 
find_all_memwins(proc_list) + + # Memories used as part of proc args must be defined in public header + header_memwins = set() + for p in proc_list: + if p.instr is None: + for arg in p.args: + memwin = arg.mem or DRAM + assert memwin in all_memwins + header_memwins.add(arg.mem or DRAM) + + header_memwin_code = [] + body_memwin_code = [] + for m in sorted(all_memwins, key=lambda x: x.name()): + code_list = header_memwin_code if m in header_memwins else body_memwin_code + code_list.append(m.global_()) + return header_memwins, header_memwin_code, body_memwin_code def _compile_context_struct(configs, lib_name): @@ -522,8 +529,9 @@ def _compile_context_struct(configs, lib_name): class Compiler: - def __init__(self, proc, ctxt_name, *, is_public_decl): + def __init__(self, proc, ctxt_name, window_struct_cache, *, is_public_decl): assert isinstance(proc, LoopIR.proc) + assert isinstance(window_struct_cache, WindowStructCache) self.proc = proc self.ctxt_name = ctxt_name @@ -536,7 +544,7 @@ def __init__(self, proc, ctxt_name, *, is_public_decl): self._lines = [] self._scalar_refs = set() self._needed_helpers = set() - self.window_defns = set() + self.window_struct_cache = window_struct_cache self._known_strides = {} assert self.proc.name is not None, "expected names for compilation" @@ -553,25 +561,9 @@ def __init__(self, proc, ctxt_name, *, is_public_decl): for a in proc.args: mem = a.mem if a.type.is_numeric() else None name_arg = self.new_varname(a.name, typ=a.type, mem=mem) - if a.type in (T.size, T.index, T.bool, T.stride): - arg_strs.append(f"{a.type.ctype()} {name_arg}") - typ_comments.append(f"{name_arg} : {a.type}") - # setup, arguments - else: - assert a.type.is_numeric() - assert a.type.basetype() != T.R - if a.type.is_real_scalar(): - self._scalar_refs.add(a.name) - if a.type.is_win(): - wintyp = self.get_window_type(a) - arg_strs.append(f"struct {wintyp} {name_arg}") - else: - const_kwd = "const " if a.name not in self.non_const else "" - ctyp = 
a.type.basetype().ctype() - arg_strs.append(f"{const_kwd}{ctyp}* {name_arg}") - mem = f" @{a.mem.name()}" if a.mem else "" - comment_str = f"{name_arg} : {a.type}{mem}" - typ_comments.append(comment_str) + if a.type.is_real_scalar(): + self._scalar_refs.add(a.name) + self.append_fnarg_decl(a, name_arg, arg_strs, typ_comments) for pred in proc.preds: if isinstance(pred, LoopIR.Const): @@ -617,6 +609,43 @@ def __init__(self, proc, ctxt_name, *, is_public_decl): self.proc_decl = proc_decl self.proc_def = proc_def + def append_fnarg_decl(self, a: LoopIR.fnarg, name_arg: str, arg_strs, typ_comments): + """Compile a LoopIR.fnarg to C function argument declaration(s). + + Appends function arguments (e.g. `int* foo`) and type comments + to the given lists, respectively. + Side effect: triggers compilation of memory definitions + and window struct declarations as needed. + """ + assert isinstance(a, LoopIR.fnarg) + mem = a.mem if a.type.is_numeric() else None + if a.type in (T.size, T.index, T.bool, T.stride): + arg_strs.append(f"{a.type.ctype()} {name_arg}") + typ_comments.append(f"{name_arg} : {a.type}") + # setup, arguments + else: + assert a.type.is_numeric() + assert a.type.basetype() != T.R + is_const = a.name not in self.non_const + if a.type.is_real_scalar(): + arg_strs.append( + f"{'const ' if is_const else ''}{a.type.ctype()}* {name_arg}" + ) + else: + assert a.type.is_tensor_or_window() + window_struct = self.get_window_struct( + a, mem or DRAM, is_const, a.type.is_win() + ) + if a.type.is_win(): + if window_struct.separate_dataptr: + arg_strs.append(f"{window_struct.dataptr} exo_data_{name_arg}") + arg_strs.append(f"struct {window_struct.name} {name_arg}") + else: + arg_strs.append(f"{window_struct.dataptr} {name_arg}") + memstr = f" @{a.mem.name()}" if a.mem else "" + comment_str = f"{name_arg} : {a.type}{memstr}" + typ_comments.append(comment_str) + def static_memory_check(self, proc): def allocates_static_memory(stmts): check = False @@ -660,14 +689,16 @@ def 
comp_stmts(self, stmts): def comp_top(self): return self.proc_decl, self.proc_def - def struct_defns(self): - return self.window_defns - def needed_helpers(self): return self._needed_helpers def new_varname(self, symbol, typ, mem=None): strnm = str(symbol) + + # Reserve "exo_" prefix for internal use. + if strnm.lower().startswith("exo_"): + strnm = "exo_user_" + strnm + if strnm not in self.names: pass else: @@ -685,6 +716,7 @@ def new_varname(self, symbol, typ, mem=None): self.env[symbol] = strnm self.envtyp[symbol] = typ if mem is not None: + assert issubclass(mem, MemWin) self.mems[symbol] = mem else: self.mems[symbol] = DRAM @@ -792,6 +824,10 @@ def get_strides(self, name: Sym, typ) -> CIR: else: return self.tensor_strides(typ.shape()) + def get_strides_s(self, name: Sym, typ) -> List[str]: + all_strides = self.get_strides(name, typ) + return [self.comp_cir(simplify_cir(i), self.env, prec=0) for i in all_strides] + def get_idx_offset(self, name: Sym, typ, idx) -> CIR: strides = self.get_strides(name, typ) assert len(strides) == len(idx) @@ -802,9 +838,10 @@ def get_idx_offset(self, name: Sym, typ, idx) -> CIR: return acc - def get_window_type(self, typ, is_const=None): + def get_window_struct(self, node, mem, is_const=None, emit_definition=True): + typ = node.type assert isinstance(typ, T.Window) or ( - isinstance(typ, LoopIR.fnarg) and typ.type.is_win() + isinstance(node, LoopIR.fnarg) and typ.is_tensor_or_window() ) if isinstance(typ, T.Window): @@ -814,13 +851,13 @@ def get_window_type(self, typ, is_const=None): is_const = typ.src_buf not in self.non_const else: base = typ.type.basetype() - n_dims = len(typ.type.shape()) + n_dims = len(typ.shape()) if is_const is None: - is_const = typ.name not in self.non_const + is_const = node.name not in self.non_const - win = window_struct(base, n_dims, is_const) - self.window_defns.add(win) - return win.name + return self.window_struct_cache.get( + mem, base, n_dims, is_const, node.srcinfo, emit_definition + ) def 
comp_s(self, s): if isinstance(s, LoopIR.Pass): @@ -843,7 +880,7 @@ def comp_s(self, s): rhs = f"({lbtyp.ctype()})({rhs})" - mem: Memory = self.mems[s.name] + mem: MemWin = self.mems[s.name] if isinstance(s, LoopIR.Assign): self.add_line(mem.write(s, lhs, rhs)) else: @@ -870,12 +907,60 @@ def comp_s(self, s): self.add_line(f"ctxt->{nm}.{s.field} = {rhs};") elif isinstance(s, LoopIR.WindowStmt): - win_struct = self.get_window_type(s.rhs.type) - rhs = self.comp_e(s.rhs) - assert isinstance(s.rhs, LoopIR.WindowExpr) - mem = self.mems[s.rhs.name] - name = self.new_varname(s.name, typ=s.rhs.type, mem=mem) - self.add_line(f"struct {win_struct} {name} = {rhs};") + rhs = s.rhs + assert isinstance(rhs, LoopIR.WindowExpr) + input_winmem = self.mems[rhs.name] + input_win_struct = self.get_window_struct(rhs, input_winmem) + ( + w_type, + w_def, + d_type, + d_def, + layout, + separate_dataptr, + ) = self.unpack_window_expr(rhs, input_winmem, input_win_struct.is_const) + + output_winmem = s.special_window or input_winmem + name = self.new_varname(s.name, typ=rhs.type, mem=output_winmem) + + if not s.special_window: + output_win_struct = input_win_struct + else: + # Special case, creating a special window + # We pass the temporary expressions from unpack_window_expr to + # the SpecialWindow creation callback. + assert issubclass(output_winmem, SpecialWindow) + assert issubclass(input_winmem, output_winmem.source_memory_type()) + tensor_type = rhs.type.as_tensor_type() + scalar_type = tensor_type.basetype() + output_win_struct = self.get_window_struct(rhs, output_winmem) + ctx = SpecialWindowFromMemoryCtx( + d_def, + layout, + output_win_struct.dataptr, + output_win_struct.name, + tensor_type, + self.shape_strs(tensor_type.shape()), + output_win_struct.is_const, + scalar_type.ctype(), + T_shorthand[scalar_type], + s.srcinfo, + ) + tmp = output_winmem.from_memory(ctx) + + # Substitute window definition for codegen, replacing temporary window. 
+ separate_dataptr = output_winmem.separate_dataptr() + if separate_dataptr: + assert len(tmp) == 2 + d_def, w_def = tmp + else: + assert isinstance(tmp, str) + d_def, w_def = None, tmp + + if separate_dataptr: + self.add_line(f"{output_win_struct.dataptr} exo_data_{name} = {d_def};") + self.add_line(f"struct {output_win_struct.name} {name} = {w_def};") + elif isinstance(s, LoopIR.If): cond = self.comp_e(s.cond) self.add_line(f"if ({cond}) {{") @@ -927,47 +1012,90 @@ def comp_s(self, s): assert all( a.type.is_win() == fna.type.is_win() for a, fna in zip(s.args, s.f.args) ) - args = [self.comp_fnarg(e, s.f, i) for i, e in enumerate(s.args)] + arg_tups = [self.comp_fnarg(e, s.f, i) for i, e in enumerate(s.args)] if s.f.instr is not None: d = dict() - assert len(s.f.args) == len(args) - for i in range(len(args)): + assert len(s.f.args) == len(arg_tups) + for i in range(len(arg_tups)): arg_name = str(s.f.args[i].name) - d[arg_name] = f"({args[i]})" + c_args, instr_data, instr_layout = arg_tups[i] arg_type = s.args[i].type if arg_type.is_win(): - assert isinstance(s.args[i], LoopIR.WindowExpr) - data, _ = self.window_struct_fields(s.args[i]) - d[f"{arg_name}_data"] = data + if not isinstance(s.args[i], LoopIR.WindowExpr): + # comp_fnarg requires this for {arg_name}_data + raise TypeError( + f"{s.srcinfo}: Argument {arg_name} must be a " + f"window expression created at the call site " + f"of {s.f.name}" + ) + # c_args = (window,) or (dataptr, layout) (depending on + # separate_dataptr); [-1] gets the window/layout + d[arg_name] = f"({c_args[-1]})" + d[f"{arg_name}_data"] = instr_data + d[f"{arg_name}_layout"] = instr_layout + # Special case for AMX instrs d[f"{arg_name}_int"] = self.env[s.args[i].name] + assert instr_data else: - d[f"{arg_name}_data"] = f"({args[i]})" + assert ( + len(c_args) == 1 + ), "didn't expect multiple c_args for non-window" + arg = f"({c_args[0]})" + d[arg_name] = arg + # Exo 1 does this; unclear why for non-windows + d[f"{arg_name}_data"] = arg 
self.add_line(f"{s.f.instr.c_instr.format(**d)}") else: fname = s.f.name - args = ["ctxt"] + args + args = ["ctxt"] + for tups in arg_tups: + c_args = tups[0] + args.extend(c_args) self.add_line(f"{fname}({','.join(args)});") else: assert False, "bad case" def comp_fnarg(self, e, fn, i, *, prec=0): + """Returns (c_args : tuple, + instr_data : Optional[str], + instr_layout : Optional[str]) + + c_args is a tuple (length 1 or 2) of formatted arguments. + Length 2 only occurs for separate_dataptr windows: (dataptr, layout). + + instr_data is for formatting c_instr windows; passed as {arg_name}_data. + This is needed both for compatibility with Exo 1 and for allowing + access to the dataptr when separate_dataptr is True. + + instr_layout is similar, passed as {arg_name}_layout. + This is an untyped initializer for the window layout (e.g. strides). + """ if isinstance(e, LoopIR.Read): assert not e.idx rtyp = self.envtyp[e.name] if rtyp.is_indexable(): - return self.env[e.name] + return (self.env[e.name],), None, None elif rtyp is T.bool: - return self.env[e.name] + return (self.env[e.name],), None, None elif rtyp is T.stride: - return self.env[e.name] + return (self.env[e.name],), None, None elif e.name in self._scalar_refs: - return self.env[e.name] + return (self.env[e.name],), None, None elif rtyp.is_tensor_or_window(): - return self.env[e.name] + c_window = self.env[e.name] + mem = fn.args[i].mem + if mem and mem.separate_dataptr(): + # This data path is exercised for calling normal + # functions, but the omitted instr_data is only + # used for instr, which can't use this code path. 
+ c_data = "exo_data_" + c_window + return (c_data, c_window), None, None + else: + return (c_window,), None, None else: assert rtyp.is_real_scalar() - return f"&{self.env[e.name]}" + return (f"&{self.env[e.name]}",), None, None elif isinstance(e, LoopIR.WindowExpr): if isinstance(fn, LoopIR.proc): callee_buf = fn.args[i].name @@ -976,11 +1104,15 @@ def comp_fnarg(self, e, fn, i, *, prec=0): ) else: raise NotImplementedError("Passing windows to externs") - win_struct = self.get_window_type(e.type, is_const) - data, strides = self.window_struct_fields(e) - return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}" + _, w_def, _, d_def, layout, separate_dataptr = self.unpack_window_expr( + e, self.mems[e.name], is_const + ) + if separate_dataptr: + return (d_def, w_def), d_def, layout + else: + return (w_def,), d_def, layout else: - return self.comp_e(e, prec) + return (self.comp_e(e, prec),), None, None def comp_e(self, e, prec=0): if isinstance(e, LoopIR.Read): @@ -988,7 +1120,7 @@ def comp_e(self, e, prec=0): if rtyp.is_indexable() or rtyp is T.bool or rtyp == T.stride: return self.env[e.name] - mem: Memory = self.mems[e.name] + mem: MemWin = self.mems[e.name] if not mem.can_read(): raise MemGenError( @@ -1004,9 +1136,12 @@ def comp_e(self, e, prec=0): return self.access_str(e.name, e.idx) elif isinstance(e, LoopIR.WindowExpr): - win_struct = self.get_window_type(e.type) - data, strides = self.window_struct_fields(e) - return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}" + # WindowExpr needs to be handled differently depending on usage + # * WindowStmt + # * Passing to function + # * Passing to instr + # see unpack_window_expr and get strings from there + assert 0, "Unexpected standalone WindowExpr" elif isinstance(e, LoopIR.Const): if isinstance(e.val, bool): @@ -1067,14 +1202,27 @@ def comp_e(self, e, prec=0): else: assert False, "bad case" - def _call_static_helper(self, helper, *args): - self._needed_helpers.add(helper) - return f'{helper}({", 
".join(map(str, args))})' + def unpack_window_expr(self, e: LoopIR.WindowExpr, src_memwin: type, is_const=None): + """(w_type, w_def, d_type, d_def, layout, separate_dataptr) + + w_type, w_def: C typename and initialization for window struct + + d_type: C typename for data pointer + + d_def: "data" passed through from src_memwin.window(...) + + layout: untyped C braced initializer for layout portion of window + + separate_dataptr: If True, the window is defined with a + separate data pointer {d_type} {name} = {d_def} + """ + win_struct = self.get_window_struct(e, src_memwin, is_const) + w_type = win_struct.name + d_type = win_struct.dataptr + separate_dataptr = win_struct.separate_dataptr - def window_struct_fields(self, e): base = self.env[e.name] - basetyp = self.envtyp[e.name] - mem: Memory = self.mems[e.name] + basetyp = self.envtyp[e.name].as_tensor_type() # compute offset to new data pointer def w_lo(w): @@ -1084,13 +1232,144 @@ def w_lo(w): idxs = [self.comp_cir(simplify_cir(i), self.env, prec=0) for i in cirs] # compute new window strides - all_strides = self.get_strides(e.name, basetyp) - all_strides_s = [ - self.comp_cir(simplify_cir(i), self.env, prec=0) for i in all_strides - ] + all_strides_s = self.get_strides_s(e.name, basetyp) assert 0 < len(all_strides_s) == len(e.idx) - dataptr = mem.window(basetyp, base, idxs, all_strides_s, e.srcinfo) - strides = ", ".join( - s for s, w in zip(all_strides_s, e.idx) if isinstance(w, LoopIR.Interval) + if separate_dataptr and basetyp.is_win(): + window_in_expr = "exo_data_" + base, base + else: + window_in_expr = base + callback_result = src_memwin.window( + basetyp, window_in_expr, idxs, all_strides_s, e.srcinfo ) - return dataptr, strides + if isinstance(callback_result, str): + # Base case, no custom layout + assert ( + not separate_dataptr + ), "MemWin must define custom layout for separate_dataptr" + strides = ", ".join( + s + for s, w in zip(all_strides_s, e.idx) + if isinstance(w, LoopIR.Interval) + ) + 
layout = f"{{ {strides} }}" + d_def = callback_result + w_def = f"(struct {w_type}){{ &{d_def}, {layout} }}" + else: + # Custom layout case + assert len(callback_result) == 2 + d_def, layout = callback_result + if separate_dataptr: + w_def = f"(struct {w_type}) {layout}" + else: + w_def = f"(struct {w_type}){{ {d_def}, {layout} }}" # not &data + # This could be an optional MemWin.window_remove_dims(...) callback + if any(isinstance(w, LoopIR.Point) for w in e.idx): + raise MemGenError( + f"{e.srcinfo}: {src_memwin.name()} window from {e.name} doesn't support removing dimensions (single Point coordinate in window indices)" + ) + + return w_type, w_def, d_type, d_def, layout, separate_dataptr + + def _call_static_helper(self, helper, *args): + self._needed_helpers.add(helper) + return f'{helper}({", ".join(map(str, args))})' + + +# --------------------------------------------------------------------------- # +# --------------------------------------------------------------------------- # +# Cached collection of window struct definitions + + +class WindowStructCache(object): + __slots__ = ["_key_to_name", "_name_to_struct"] + + def __init__(self): + self._key_to_name = {} + self._name_to_struct = {} + + def _add_to_cache(self, key_tuple, srcinfo) -> WindowStruct: + memwin, base_type, n_dims, is_const = key_tuple + type_shorthand = T_shorthand[base_type] + separate_dataptr = memwin.separate_dataptr() + + ctx = WindowStructCtx( + base_type.ctype(), + type_shorthand, + n_dims, + is_const, + separate_dataptr, + srcinfo, + ) + c_dataptr, c_window = memwin.window_definition(ctx) + + assert isinstance(c_dataptr, str) + assert isinstance(c_window, str) + assert isinstance(separate_dataptr, bool) + + assert ctx._struct_name is not None, "MemWin didn't name the struct" + sname = ctx._struct_name + + self._key_to_name[key_tuple] = sname + + sdef = f"""#ifndef {ctx._guard_macro} +#define {ctx._guard_macro} +{c_window} +#endif""" + + v = self._name_to_struct.get(sname) + + if v is 
None: + v = WindowStruct( + ctx._struct_name, + sdef, + c_dataptr, + separate_dataptr, + is_const, + False, # emit_definition flag; modified outside this function + ) + self._name_to_struct[sname] = v + elif v.definition != sdef: + # Since windows are keyed based on MemWin type, and derived MemWin + # types inherit an identical window struct if not overridden, + # it's valid to have a struct name collision here. + # But we validate that the collision is due to a duplicate + # identical struct, and not a true name incompatibility. + for key_tuple2, sname2 in self._key_to_name.items(): + if sname2 == sname: + memwin2, base_type2, n_dims2, is_const2 = key_tuple2 + type_shorthand2 = T_shorthand[base_type2] + raise ValueError( + f"""Window name collision for {sname}: +{memwin.name()}, {type_shorthand}, n_dims={n_dims}, is_const={is_const}; +{memwin2.name()}, {type_shorthand2}, n_dims={n_dims2}, is_const={is_const2}""" + ) + + return v + + def get( + self, memwin, base_type, n_dims, is_const, srcinfo, emit_definition + ) -> WindowStruct: + key_tuple = (memwin, base_type, n_dims, is_const) + sname = self._key_to_name.get(key_tuple) + if sname is None: + v = self._add_to_cache(key_tuple, srcinfo) + else: + v = self._name_to_struct[sname] + v.emit_definition |= emit_definition + return v + + def sorted_header_body_definitions(self, header_memwins): + header_snames = set() + for key_tuple, sname in self._key_to_name.items(): + memwin, _, _, _ = key_tuple + if memwin in header_memwins: + header_snames.add(sname) + + sorted_pairs = sorted(self._name_to_struct.items()) + h_definitions = [] + c_definitions = [] + for sname, struct in sorted_pairs: + if struct.emit_definition: + lst = h_definitions if sname in header_snames else c_definitions + lst.append(struct.definition) + return h_definitions, c_definitions diff --git a/src/exo/backend/mem_analysis.py b/src/exo/backend/mem_analysis.py index 39eaf267c..8f46dacf2 100644 --- a/src/exo/backend/mem_analysis.py +++ 
b/src/exo/backend/mem_analysis.py @@ -1,7 +1,7 @@ from collections import ChainMap from ..core.LoopIR import LoopIR -from ..core.memory import Memory +from ..core.memory import MemWin, Memory, SpecialWindow # --------------------------------------------------------------------------- # @@ -23,7 +23,7 @@ def run(self, proc): for a in proc.args: if a.type.is_numeric(): mem = a.mem - assert issubclass(mem, Memory) + assert issubclass(mem, MemWin) self.mem_env[a.name] = mem self.push() @@ -134,7 +134,7 @@ def mem_s(self, s): elif styp is LoopIR.WindowStmt: mem = self.get_e_mem(s.rhs) - self.mem_env[s.name] = mem + self.mem_env[s.name] = s.special_window or mem return s elif styp is LoopIR.Call: @@ -142,7 +142,7 @@ def mem_s(self, s): for ca, sa in zip(s.args, s.f.args): if sa.type.is_numeric(): smem = sa.mem - assert issubclass(smem, Memory) + assert issubclass(smem, MemWin) cmem = self.get_e_mem(ca) if not issubclass(cmem, smem): raise TypeError( diff --git a/src/exo/core/LoopIR.py b/src/exo/core/LoopIR.py index ee67bbd31..535c4587b 100644 --- a/src/exo/core/LoopIR.py +++ b/src/exo/core/LoopIR.py @@ -6,7 +6,7 @@ from .extern import Extern from .configs import Config -from .memory import Memory +from .memory import MemWin, Memory, SpecialWindow from .prelude import Sym, SrcInfo, extclass @@ -70,7 +70,7 @@ def __new__(cls, op): fnarg = ( sym name, type type, - mem? mem, + memwin? mem, srcinfo srcinfo ) stmt = Assign( sym name, type type, expr* idx, expr rhs ) @@ -82,7 +82,7 @@ def __new__(cls, op): | Alloc( sym name, type type, mem mem ) | Free( sym name, type type, mem mem ) | Call( proc f, expr* args ) - | WindowStmt( sym name, expr rhs ) + | WindowStmt( sym name, expr rhs, special_window? 
special_window ) attributes( srcinfo srcinfo ) loop_mode = Seq() @@ -141,7 +141,9 @@ def __new__(cls, op): ext_types={ "name": validators.instance_of(Identifier, convert=True), "sym": Sym, + "memwin": Type[MemWin], "mem": Type[Memory], + "special_window": Type[SpecialWindow], "extern": Extern, "config": Config, "binop": validators.instance_of(Operator, convert=True), @@ -184,7 +186,7 @@ def __new__(cls, op): fnarg = ( sym name, type type, - mem? mem, + memwin? mem, srcinfo srcinfo ) stmt = Assign ( sym name, expr* idx, expr rhs ) @@ -203,7 +205,7 @@ def __new__(cls, op): | USub ( expr arg ) -- i.e. -(...) | BinOp ( op op, expr lhs, expr rhs ) | Extern( extern f, expr* args ) - | WindowExpr( sym name, w_access* idx ) + | WindowExpr( sym name, w_access* idx, special_window? special_window ) | StrideExpr( sym name, int dim ) | ParRange( expr lo, expr hi ) -- only use for loop cond | SeqRange( expr lo, expr hi ) -- only use for loop cond @@ -232,7 +234,9 @@ def __new__(cls, op): ext_types={ "name": validators.instance_of(Identifier, convert=True), "sym": Sym, + "memwin": Type[MemWin], "mem": Type[Memory], + "special_window": Type[SpecialWindow], "extern": Extern, "config": Config, "loopir_proc": LoopIR.proc, @@ -407,6 +411,19 @@ class T: # type helper functions +@extclass(T.Tensor) +def as_tensor_type(t): + return t + + +@extclass(T.Window) +def as_tensor_type(t): + return t.as_tensor + + +del as_tensor_type + + @extclass(T.Tensor) @extclass(T.Window) @extclass(T.Num) diff --git a/src/exo/core/LoopIR_pprint.py b/src/exo/core/LoopIR_pprint.py index 4464976e3..e087897a7 100644 --- a/src/exo/core/LoopIR_pprint.py +++ b/src/exo/core/LoopIR_pprint.py @@ -428,6 +428,8 @@ def _print_stmt(stmt, env: PrintEnv, indent: str) -> list[str]: elif isinstance(stmt, LoopIR.WindowStmt): rhs = _print_expr(stmt.rhs, env) + if stmt.special_window is not None: + rhs = f"{rhs} @ {stmt.special_window.name()}" return [f"{indent}{env.get_name(stmt.name)} = {rhs}"] elif isinstance(stmt, 
LoopIR.Alloc): diff --git a/src/exo/core/memory.py b/src/exo/core/memory.py index c5d1a1770..08465f13c 100644 --- a/src/exo/core/memory.py +++ b/src/exo/core/memory.py @@ -16,7 +16,10 @@ # * write to the memory (optional) # * reduce to the memory (optional) """ +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import Optional """ --- Alloc specifications --- @@ -54,11 +57,18 @@ """ +_memwin_template_names = {} +_memwin_template_cache = {} + + class MemGenError(Exception): pass -def generate_offset(indices, strides): +def generate_offset(indices, strides, vector_size=1): + assert isinstance(vector_size, int), "generalize this if needed" + assert vector_size >= 1 + def index_expr(i, s): if s == "0" or i == "0": return "" @@ -75,13 +85,105 @@ def index_expr(i, s): exprs = [e for i, s in zip(indices, strides) if (e := index_expr(i, s)) != ""] - return " + ".join(exprs) if len(exprs) > 0 else "0" + expr = " + ".join(exprs) if len(exprs) > 0 else "0" + if vector_size != 1 and expr != "0": + expr = f"({expr}) / {vector_size}" + + return expr + + +class WindowStructCtx(object): + __slots__ = [ + "_ctype", + "_type_shorthand", + "_n_dims", + "_is_const", + "_separate_dataptr", + "_srcinfo", + "_struct_name", + "_guard_macro", + ] + + def __init__( + self, ctype, type_shorthand, n_dims, is_const, separate_dataptr, srcinfo + ): + """For internal use of LoopIR compiler""" + self._ctype = ctype + self._type_shorthand = type_shorthand + self._n_dims = n_dims + self._is_const = is_const + self._separate_dataptr = separate_dataptr + self._srcinfo = srcinfo + + self._struct_name = None + self._guard_macro = None + + def generate_default(self, memwin_name, data_ctype=None): + sname = self.struct_name(memwin_name) + if data_ctype is None: + data_ctype = self._ctype + # Spacing difference gives byte-for-byte compatibility with Exo 1. 
+ struct_cptr = "const " * self._is_const + data_ctype + " *" + dataptr_ctype = "const " * self._is_const + data_ctype + "*" + + sdef = ( + f"struct {sname}{{\n" + f" {struct_cptr} const data;\n" + f" const int_fast32_t strides[{self._n_dims}];\n" + f"}};" + ) + return dataptr_ctype, sdef + + def struct_name(self, memwin_name: str, mangle_parameters=None) -> str: + """Must be called at least once (and consistently) to name the struct.""" + assert isinstance(memwin_name, str), "use str (avoid silent mistakes)" + assert memwin_name + + if mangle_parameters: + for p in mangle_parameters: + assert isinstance(p, int), "Only support mangled names for ints" + if p >= 0: + memwin_name += f"_{p}" + else: + memwin_name += f"_n{-p}" + + # As promised in MemWin.separate_dataptr, if True, disable const suffix + const_suffix = "c" if self._is_const and not self._separate_dataptr else "" + base_sname = f"exo_win_{self._n_dims}{self._type_shorthand}{const_suffix}" + mem_suffix = "" if memwin_name == "DRAM" else "_" + memwin_name + sname = base_sname + mem_suffix + + assert self._struct_name is None or self.struct_name == sname + self._struct_name = sname + self._guard_macro = base_sname.upper() + mem_suffix # case-sensitive + + return sname + def n_dims(self) -> int: + return self._n_dims + + def is_const(self) -> bool: + return self._is_const + + def ctype(self) -> str: + """return C name for scalar type tensor is made of e.g. float, uint16_t""" + return self._ctype + + def type_shorthand(self) -> str: + """e.g. 
f32, u16""" + return self._type_shorthand + + def srcinfo(self): + """Convert to str and include in error messages""" + return self._srcinfo + + +class MemWin(ABC): + """Common base class of allocable Memory and non-allocable SpecialWindow""" -class Memory(ABC): @classmethod def name(cls): - return cls.__name__ + return _memwin_template_names.get(cls) or cls.__name__ @classmethod def global_(cls): @@ -92,25 +194,76 @@ def global_(cls): @classmethod @abstractmethod - def alloc(cls, new_name, prim_type, shape, srcinfo): + def window_definition(cls, ctx: WindowStructCtx): """ - python gemmini_extended_compute_preloaded + C code defining struct. + Get the required parameters from the WindowStructCtx. + Return (dataptr : str, window_struct : str) + + dataptr: C type for a raw pointer (e.g. __m256d*, float*) + + window_struct: C code defining a struct named ctx.struct_name() + + The compiler will include a header guard for you. """ raise NotImplementedError() @classmethod - @abstractmethod - def free(cls, new_name, prim_type, shape, srcinfo): - raise NotImplementedError() + def separate_dataptr(cls): + """separate_dataptr: return False for the usual case. + + If True, the window is passed to functions as separate arguments + (dataptr, window_struct) rather than a combined window struct; + the window struct only contains layout information in this case, + and you must define this custom layout (see window(...)) + + In this case, the layout-only window struct is the same for both + const and non-const windows. + """ + return False @classmethod - def window(cls, basetyp, baseptr, indices, strides, srcinfo): - offset = generate_offset(indices, strides) + def window(cls, basetyp, in_expr, indices, strides, srcinfo) -> str: + """ + Return one of the following: + + Base case: data : str + Custom layout: (dataptr : str, layout : str) - if basetyp.is_win(): - baseptr = f"{baseptr}.data" + Where dataptr and layout are both C strings used to initialize + the window struct. 
(A default layout is provided in non-custom cases). + We implicitly take dataptr = &data in the base case. - return f"{baseptr}[{offset}]" + If you wish to implement can_read/write/reduce, you should not use + a custom layout. Furthermore, currently custom layouts don't support + reducing the dimensionality of a window (can be changed later). + + basetyp: LoopIR.Tensor instance + + in_expr: C expression of the following type: + + basetyp.is_win() = false: dense tensor type (as generated by alloc) + Won't occur if implementing a SpecialWindow + + basetyp.is_win() = True: window type + str if no separate_dataptr, else (dataptr : str, layout : str) + + indices: C expressions of indices (offsets per dimension) + e.g. [1:10, 42:46] -> ["1", "42"] (we don't provide the slice sizes) + + strides: C expressions of per-dim strides, in units of scalars. (*) + (*) consider passing vector_size to generate_offset. + + srcinfo: include this when throwing an exception. + """ + return cls.default_window(1, basetyp, in_expr, indices, strides, srcinfo) + + @classmethod + def default_window(cls, vector_size, basetyp, in_expr, indices, strides, srcinfo): + """Helper for simple window(...) implementations. Don't override this""" + offset = generate_offset(indices, strides, vector_size) + dataptr = f"{in_expr}.data" if basetyp.is_win() else in_expr + return f"{dataptr}[{offset}]" @classmethod @abstractmethod @@ -132,13 +285,182 @@ def reduce(cls, s, lhs, rhs): ) +class Memory(MemWin): + @classmethod + @abstractmethod + def alloc(cls, new_name, prim_type, shape, srcinfo): + """ + python gemmini_extended_compute_preloaded + """ + raise NotImplementedError() + + @classmethod + @abstractmethod + def free(cls, new_name, prim_type, shape, srcinfo): + raise NotImplementedError() + + @classmethod + def window_definition(cls, ctx: WindowStructCtx): + """This is not correct for non-scalar cases but we provide this + for backwards compatibility with Exo 1 ... 
programs worked OK + if they never materialized the faulty default window struct""" + return ctx.generate_default("DRAM") + + +class SpecialWindowFromMemoryCtx(object): + # TODO since we only give access to runtime window struct, + # it's currently not possible to compile-time assert stride info. + __slots__ = [ + "_src_data", + "_src_layout", + "_dst_dataptr_ctype", + "_dst_struct_name", + "_tensor_type", + "_shape_strs", + "_is_const", + "_ctype", + "_type_shorthand", + "_srcinfo", + ] + + def __init__( + self, + src_data, + src_layout, + dst_dataptr_ctype, + dst_struct_name, + tensor_type, + shape_strs, + is_const, + ctype, + type_shorthand, + srcinfo, + ): + """For internal use of LoopIR compiler""" + self._src_data = src_data + self._src_layout = src_layout + self._dst_dataptr_ctype = dst_dataptr_ctype + self._dst_struct_name = dst_struct_name + self._tensor_type = tensor_type + self._shape_strs = shape_strs + self._is_const = is_const + self._ctype = ctype + self._type_shorthand = type_shorthand + self._srcinfo = srcinfo + + def src_data(self): + """C initializer for source window data pointer + + Passed through from Memory.window of the source memory type""" + return self._src_data + + def src_layout(self): + """Untyped C initializer for source window layout (e.g. 
strides) + + Passed through (or default strides) from Memory.window of + the source memory type""" + return self._src_layout + + def dst_dataptr_ctype(self): + """C type name of SpecialWindow data pointer (you defined this)""" + return self._dst_dataptr_ctype + + def dst_struct_name(self): + """C struct name of SpecialWindow window struct (you defined this)""" + return self._dst_struct_name + + def tensor_type(self) -> LoopIR.Tensor: + """return LoopIR.Tensor type of input tensor""" + assert isinstance(self._tensor_type, LoopIR.Tensor) + return self._tensor_type + + def shape_strs(self): + """C strings defining dimension sizes of window""" + return self._shape_strs + + def is_const(self) -> bool: + return self._is_const + + def ctype(self) -> str: + """return C name for scalar type tensor is made of e.g. float, uint16_t""" + return self._ctype + + def type_shorthand(self) -> str: + """e.g. f32, u16""" + return self._type_shorthand + + def srcinfo(self): + """Convert to str and include in error messages""" + return self._srcinfo + + +class SpecialWindow(MemWin): + @classmethod + @abstractmethod + def source_memory_type(cls) -> type: + """Return memory type expected as input to window statement""" + raise NotImplementedError() + + @classmethod + @abstractmethod + def from_memory(cls, ctx: SpecialWindowFromMemoryCtx): + """Callback for generating C code initializing a special window + from a window to a tensor of the source memory type. + + If separate_dataptr(), return (dataptr : str, layout : str) of + C expressions that can initialize the two respective window variables. + Otherwise, return a single C expression that can be used + to initialize a struct of the window type. + """ + raise NotImplementedError() + + # Remember to implement everything in base class MemWin as well + + +# ----------- TEMPLATE SYSTEM ------------- + + +def memwin_template(class_factory): + """Wrapper for creating MemWin types parameterized on a tuple of args. 
+ + The name of the generated class will look like a function call + e.g. MyMemoryName(64, 128) [akin to MyMemoryName<64, 128> in C++]. + Cached: identically parameterized MemWins will be identical Python types. + + The parameter tuple is injected to the class as memwin_template_parameters + + Usage: + + @memwin_template + def MyMemoryName(*parameters): + class MemoryImpl(Memory): # class name is ignored + ...implement memory normally + return MemoryImpl + """ + + def class_factory_wrapper(*parameters, **kwargs): + assert not kwargs, "No support for keyword template parameters" + cache_key = (id(class_factory), parameters) + cls = _memwin_template_cache.get(cache_key) + if not cls: + cls = class_factory(*parameters) + cls_name = f"{class_factory.__name__}{parameters}" + _memwin_template_cache[cache_key] = cls + _memwin_template_names[cls] = cls_name + assert not hasattr(cls, "memwin_template_parameters") + cls.memwin_template_parameters = parameters + return cls + + return class_factory_wrapper + + # ----------- DRAM on LINUX ---------------- class DRAM(Memory): @classmethod def global_(cls): - return "#include \n" "#include \n" + return "#include \n#include \n" @classmethod def alloc(cls, new_name, prim_type, shape, srcinfo): diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index f85c928ff..53468742a 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -14,6 +14,7 @@ from ..core.LoopIR import UAST, PAST, front_ops from ..core.prelude import * from ..core.extern import Extern +from ..core.memory import MemWin, Memory, SpecialWindow from typing import Any, Callable, Union, NoReturn, Optional import copy @@ -919,7 +920,15 @@ def is_at(x): else: typ = self.parse_num_type(typ_node, is_arg=True) - mem = self.eval_expr(mem_node) if mem_node else None + if mem_node: + mem = self.eval_expr(mem_node) + if not isinstance(mem, type) or not issubclass(mem, MemWin): + self.err( + node, + "annotation needs to be subclass of 
Memory or SpecialWindow", + ) + else: + mem = None return typ, mem @@ -929,6 +938,8 @@ def parse_alloc_typmem(self, node): # x[n] @ DRAM # x[n] @ lib.scratch mem = self.eval_expr(node.right) + if not isinstance(mem, type) or not issubclass(mem, Memory): + self.err(node, "expected @mem with mem a subclass of Memory") node = node.left else: mem = None @@ -1555,7 +1566,8 @@ def parse_expr(self, e): ) if is_window: - return UAST.WindowExpr(nm, idxs, self.getsrcinfo(e)) + # SpecialWindow handled by BinOp parser + return UAST.WindowExpr(nm, idxs, None, self.getsrcinfo(e)) else: return UAST.Read(nm, idxs, self.getsrcinfo(e)) @@ -1600,6 +1612,20 @@ def parse_expr(self, e): elif isinstance(e, pyast.BinOp): lhs = self.parse_expr(e.left) + + # tensor[idxs...] @ SpecialWindow + if ( + isinstance(e.op, pyast.MatMult) + and hasattr(self.AST, "WindowExpr") + and isinstance(lhs, self.AST.WindowExpr) + ): + special_window = self.eval_expr(e.right) + if not isinstance(special_window, type) or not issubclass( + special_window, SpecialWindow + ): + self.err(e, "expected @win with win a subclass of SpecialWindow") + return lhs.update(special_window=special_window) + rhs = self.parse_expr(e.right) if isinstance(e.op, pyast.Add): op = "+" diff --git a/src/exo/frontend/typecheck.py b/src/exo/frontend/typecheck.py index 9c89cb7e9..9b4ddd7b2 100644 --- a/src/exo/frontend/typecheck.py +++ b/src/exo/frontend/typecheck.py @@ -209,7 +209,7 @@ def check_access(self, node, nm, idx, lvalue=False): def check_single_stmt(self, stmt): if isinstance(stmt, UAST.FreshAssign): - rhs = self.check_e(stmt.rhs) + rhs = self.check_e(stmt.rhs, allow_special_window=True) # We see a statement of the form # nm = ... 
@@ -223,7 +223,11 @@ def check_single_stmt(self, stmt): elif isinstance(rhs.type, T.Window): assert isinstance(rhs, LoopIR.WindowExpr) self.env[stmt.name] = rhs.type - return [LoopIR.WindowStmt(stmt.name, rhs, stmt.srcinfo)] + return [ + LoopIR.WindowStmt( + stmt.name, rhs, stmt.rhs.special_window, stmt.srcinfo + ) + ] else: self.err( stmt, @@ -377,7 +381,7 @@ def check_w_access(self, e, orig_hi): return LoopIR.Interval(lo, hi, e.srcinfo) - def check_e(self, e, is_index=False): + def check_e(self, e, is_index=False, allow_special_window=False): if isinstance(e, UAST.Read): typ = self.env[e.name] # if we only partially accessed the base tensor/window, @@ -391,7 +395,7 @@ def check_e(self, e, is_index=False): for _ in range(0, len(typ.shape()) - len(e.idx)) ] - desugared = UAST.WindowExpr(e.name, idxs, e.srcinfo) + desugared = UAST.WindowExpr(e.name, idxs, None, e.srcinfo) return self.check_e(desugared) # otherwise, we have a normal access @@ -409,6 +413,24 @@ def check_e(self, e, is_index=False): ) return LoopIR.WindowExpr(e.name, [], T.err, e.srcinfo) + if e.special_window is not None: + # UAST has the optional special window as part of WindowExpr as + # that's how it parses, but LoopIR has the special window as + # part of WindowStmt since that matches the usage pattern + # (can't construct special windows just anywhere) + if not allow_special_window: + self.err( + e, + f"Can only create SpecialWindow as part of " + f"WindowStmt (W = t[idx...] 
@ SpecialWindow)", + ) + elif not in_typ.is_dense_tensor(): + self.err( + e, + "Can only create SpecialWindow from a dense " + "tensor, not another window", + ) + in_shape = in_typ.shape() if len(in_shape) != len(e.idx): self.err( diff --git a/src/exo/libs/memories.py b/src/exo/libs/memories.py index 8ce892997..3244d9e8a 100644 --- a/src/exo/libs/memories.py +++ b/src/exo/libs/memories.py @@ -188,21 +188,32 @@ def window(cls, basetyp, baseptr, indices, strides, srcinfo): class AVX2(Memory): + _vec_types = { + "float": (8, "__m256"), + "double": (4, "__m256d"), + "uint16_t": (16, "__m256i"), + } + @classmethod def global_(cls): return "#include " + @classmethod + def window_definition(cls, ctx): + if ctx.n_dims() != 1: + raise MemGenError( + f"{ctx.srcinfo()}: Only support windows to a single AVX vector (n_dims 1)" + ) + _, c_vec = cls._vec_types[ctx.ctype()] + return ctx.generate_default("AVX2", c_vec) + @classmethod def alloc(cls, new_name, prim_type, shape, srcinfo): + vec_types = cls._vec_types + if not shape: raise MemGenError(f"{srcinfo}: AVX2 vectors are not scalar values") - vec_types = { - "float": (8, "__m256"), - "double": (4, "__m256d"), - "uint16_t": (16, "__m256i"), - } - if not prim_type in vec_types.keys(): raise MemGenError( f"{srcinfo}: AVX2 vectors must be f32/f64/ui16 (for now), got {prim_type}" @@ -230,6 +241,8 @@ def free(cls, new_name, prim_type, shape, srcinfo): @classmethod def window(cls, basetyp, baseptr, indices, strides, srcinfo): + if basetyp.is_win(): + return f"*{baseptr}.data" assert strides[-1] == "1" idxs = indices[:-1] or "" if idxs: @@ -245,6 +258,14 @@ class AVX512(Memory): def global_(cls): return "#include " + @classmethod + def window_definition(cls, ctx): + if ctx.n_dims() != 1: + raise MemGenError( + f"{ctx.srcinfo()}: Only support windows to a single AVX vector (n_dims 1)" + ) + return ctx.generate_default("AVX512", "__m512") + @classmethod def can_read(cls): return False @@ -270,6 +291,8 @@ def free(cls, new_name, 
prim_type, shape, srcinfo): @classmethod def window(cls, basetyp, baseptr, indices, strides, srcinfo): + if basetyp.is_win(): + return f"*{baseptr}.data" assert strides[-1] == "1" idxs = indices[:-1] or "" if idxs: diff --git a/src/exo/rewrite/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py index c67f0ec70..ebac7ee8f 100644 --- a/src/exo/rewrite/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -901,7 +901,7 @@ def DoInline(call): def map_bind(nm, a): if isinstance(a, LoopIR.WindowExpr): - stmt = LoopIR.WindowStmt(nm, a, a.srcinfo) + stmt = LoopIR.WindowStmt(nm, a, None, a.srcinfo) win_binds.append(stmt) return LoopIR.Read(nm, [], a.type, a.srcinfo) return a diff --git a/tests/golden/test_apps/test_blur.txt b/tests/golden/test_apps/test_blur.txt index bc2e5b726..171d8f8ec 100644 --- a/tests/golden/test_apps/test_blur.txt +++ b/tests/golden/test_apps/test_blur.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1UI16 #define EXO_WIN_1UI16 struct exo_win_1ui16{ @@ -73,9 +76,20 @@ void exo_blur_halide( void *ctxt, int_fast32_t W, int_fast32_t H, uint16_t* blur #include #include -#include -#include - +#ifndef EXO_WIN_1UI16_AVX2 +#define EXO_WIN_1UI16_AVX2 +struct exo_win_1ui16_AVX2{ + __m256i * const data; + const int_fast32_t strides[1]; +}; +#endif +#ifndef EXO_WIN_1UI16C_AVX2 +#define EXO_WIN_1UI16C_AVX2 +struct exo_win_1ui16c_AVX2{ + const __m256i * const data; + const int_fast32_t strides[1]; +}; +#endif /* relying on the following instruction..." 
avx2_ui16_divide_by_3(out,x) diff --git a/tests/golden/test_apps/test_gemmini_conv.txt b/tests/golden/test_apps/test_gemmini_conv.txt index c50409bde..a6d7cd79e 100644 --- a/tests/golden/test_apps/test_gemmini_conv.txt +++ b/tests/golden/test_apps/test_gemmini_conv.txt @@ -50,13 +50,9 @@ typedef struct test_case_Context { } ConfigStore; } test_case_Context; -#ifndef EXO_WIN_2I32 -#define EXO_WIN_2I32 -struct exo_win_2i32{ - int32_t * const data; - const int_fast32_t strides[2]; -}; -#endif +#include +#include + #ifndef EXO_WIN_2I32C #define EXO_WIN_2I32C struct exo_win_2i32c{ @@ -78,13 +74,6 @@ struct exo_win_2i8c{ const int_fast32_t strides[2]; }; #endif -#ifndef EXO_WIN_3I8 -#define EXO_WIN_3I8 -struct exo_win_3i8{ - int8_t * const data; - const int_fast32_t strides[3]; -}; -#endif // conv_17( // output : i8[4, 28, 28, 128] @DRAM, // bias : i32[1, 128] @DRAM, @@ -154,13 +143,24 @@ void conv_3_cpu( test_case_Context *ctxt, int8_t* output, const int32_t* bias, c #include "test_case.h" -#include -#include - #include #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" +#ifndef EXO_WIN_2I32 +#define EXO_WIN_2I32 +struct exo_win_2i32{ + int32_t * const data; + const int_fast32_t strides[2]; +}; +#endif +#ifndef EXO_WIN_3I8 +#define EXO_WIN_3I8 +struct exo_win_3i8{ + int8_t * const data; + const int_fast32_t strides[3]; +}; +#endif int8_t _relu_int8_t(int8_t x) { if (x > 0.0) return x; else return 0.0; diff --git a/tests/golden/test_apps/test_gemmini_matmul.txt b/tests/golden/test_apps/test_gemmini_matmul.txt index 89a5be56e..2afb4ce60 100644 --- a/tests/golden/test_apps/test_gemmini_matmul.txt +++ b/tests/golden/test_apps/test_gemmini_matmul.txt @@ -54,20 +54,9 @@ typedef struct test_case_Context { } ConfigStore; } test_case_Context; -#ifndef EXO_WIN_2I32 -#define EXO_WIN_2I32 -struct exo_win_2i32{ - int32_t * const data; - const int_fast32_t strides[2]; -}; -#endif -#ifndef EXO_WIN_2I32C -#define EXO_WIN_2I32C -struct exo_win_2i32c{ - const int32_t * const 
data; - const int_fast32_t strides[2]; -}; -#endif +#include +#include + #ifndef EXO_WIN_2I8 #define EXO_WIN_2I8 struct exo_win_2i8{ @@ -82,13 +71,6 @@ struct exo_win_2i8c{ const int_fast32_t strides[2]; }; #endif -#ifndef EXO_WIN_3I8 -#define EXO_WIN_3I8 -struct exo_win_3i8{ - int8_t * const data; - const int_fast32_t strides[3]; -}; -#endif // cpu_matmul_14( // scale : f32 @DRAM, // act : bool, @@ -206,13 +188,31 @@ void matmul_6( test_case_Context *ctxt, const float* scale, bool act, const int8 #include "test_case.h" -#include -#include - #include #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" +#ifndef EXO_WIN_2I32 +#define EXO_WIN_2I32 +struct exo_win_2i32{ + int32_t * const data; + const int_fast32_t strides[2]; +}; +#endif +#ifndef EXO_WIN_2I32C +#define EXO_WIN_2I32C +struct exo_win_2i32c{ + const int32_t * const data; + const int_fast32_t strides[2]; +}; +#endif +#ifndef EXO_WIN_3I8 +#define EXO_WIN_3I8 +struct exo_win_3i8{ + int8_t * const data; + const int_fast32_t strides[3]; +}; +#endif int8_t _relu_int8_t(int8_t x) { if (x > 0.0) return x; else return 0.0; diff --git a/tests/golden/test_apps/test_neon_sgemm.txt b/tests/golden/test_apps/test_neon_sgemm.txt index b91b0c0a2..edcb269e4 100644 --- a/tests/golden/test_apps/test_neon_sgemm.txt +++ b/tests/golden/test_apps/test_neon_sgemm.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1F32 #define EXO_WIN_1F32 struct exo_win_1f32{ @@ -77,9 +80,6 @@ void sgemm_exo( void *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, cons #include "test_case.h" -#include -#include - #include // neon_microkernel( // K : size, diff --git a/tests/golden/test_apps/test_unsharp.txt b/tests/golden/test_apps/test_unsharp.txt index a8c93077c..ed44b3f8f 100644 --- a/tests/golden/test_apps/test_unsharp.txt +++ b/tests/golden/test_apps/test_unsharp.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1F32 #define EXO_WIN_1F32 struct exo_win_1f32{ @@ -81,9 
+84,20 @@ void exo_unsharp_vectorized( void *ctxt, int_fast32_t W, int_fast32_t H, float* #include #include -#include -#include - +#ifndef EXO_WIN_1F32_AVX2 +#define EXO_WIN_1F32_AVX2 +struct exo_win_1f32_AVX2{ + __m256 * const data; + const int_fast32_t strides[1]; +}; +#endif +#ifndef EXO_WIN_1F32C_AVX2 +#define EXO_WIN_1F32C_AVX2 +struct exo_win_1f32c_AVX2{ + const __m256 * const data; + const int_fast32_t strides[1]; +}; +#endif // exo_unsharp( // W : size, // H : size, diff --git a/tests/golden/test_apps/test_x86_conv.txt b/tests/golden/test_apps/test_x86_conv.txt index 7070eca2d..1d3c61a25 100644 --- a/tests/golden/test_apps/test_x86_conv.txt +++ b/tests/golden/test_apps/test_x86_conv.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1F32 #define EXO_WIN_1F32 struct exo_win_1f32{ @@ -62,9 +65,20 @@ void conv_specialized( void *ctxt, const float* inp, float* output, const float* #include "test_case.h" #include -#include -#include - +#ifndef EXO_WIN_1F32_AVX512 +#define EXO_WIN_1F32_AVX512 +struct exo_win_1f32_AVX512{ + __m512 * const data; + const int_fast32_t strides[1]; +}; +#endif +#ifndef EXO_WIN_1F32C_AVX512 +#define EXO_WIN_1F32C_AVX512 +struct exo_win_1f32c_AVX512{ + const __m512 * const data; + const int_fast32_t strides[1]; +}; +#endif // conv_specialized( // inp : f32[5, 82, 102, 128] @DRAM, // output : f32[5, 80, 100, 128] @DRAM, diff --git a/tests/golden/test_apps/test_x86_sgemm.txt b/tests/golden/test_apps/test_x86_sgemm.txt index 37b9d3259..d724a27af 100644 --- a/tests/golden/test_apps/test_x86_sgemm.txt +++ b/tests/golden/test_apps/test_x86_sgemm.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1F32 #define EXO_WIN_1F32 struct exo_win_1f32{ @@ -81,9 +84,20 @@ void sgemm_exo( void *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, cons #include #include -#include -#include - +#ifndef EXO_WIN_1F32_AVX512 +#define EXO_WIN_1F32_AVX512 +struct exo_win_1f32_AVX512{ + __m512 * const data; 
+ const int_fast32_t strides[1]; +}; +#endif +#ifndef EXO_WIN_1F32C_AVX512 +#define EXO_WIN_1F32C_AVX512 +struct exo_win_1f32c_AVX512{ + const __m512 * const data; + const int_fast32_t strides[1]; +}; +#endif // basic_kernel_1x4( // K : size, // A : [f32][1, K] @DRAM, diff --git a/tests/golden/test_codegen/test_CIR_USub.txt b/tests/golden/test_codegen/test_CIR_USub.txt index 0adfdae41..161a4329b 100644 --- a/tests/golden/test_codegen/test_CIR_USub.txt +++ b/tests/golden/test_codegen/test_CIR_USub.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + // foo( // N : size, @@ -46,9 +49,6 @@ void foo( void *ctxt, int_fast32_t N, float* x ); #include "test.h" -#include -#include - // foo( // N : size, // x : f32[N] @DRAM diff --git a/tests/golden/test_codegen/test_const_buffer_parameters.txt b/tests/golden/test_codegen/test_const_buffer_parameters.txt index de6aee53b..67190d431 100644 --- a/tests/golden/test_codegen/test_const_buffer_parameters.txt +++ b/tests/golden/test_codegen/test_const_buffer_parameters.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1F32 #define EXO_WIN_1F32 struct exo_win_1f32{ @@ -74,9 +77,6 @@ void memcpy_b( void *ctxt, int_fast32_t N, float* A, struct exo_win_1f32c B ); #include "test.h" -#include -#include - // memcpy( // N : size, // A : f32[N] @DRAM, diff --git a/tests/golden/test_codegen/test_const_local_buffer.txt b/tests/golden/test_codegen/test_const_local_buffer.txt index 7ddbfd275..e8e079ec8 100644 --- a/tests/golden/test_codegen/test_const_local_buffer.txt +++ b/tests/golden/test_codegen/test_const_local_buffer.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1F32 #define EXO_WIN_1F32 struct exo_win_1f32{ @@ -50,9 +53,6 @@ void caller( void *ctxt ); #endif // TEST_H #include "test.h" -#include -#include - // callee( // N : size, // A : [f32][N] @DRAM diff --git a/tests/golden/test_codegen/test_const_local_window.txt 
b/tests/golden/test_codegen/test_const_local_window.txt index ab101aca4..8f0dccd19 100644 --- a/tests/golden/test_codegen/test_const_local_window.txt +++ b/tests/golden/test_codegen/test_const_local_window.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1F32 #define EXO_WIN_1F32 struct exo_win_1f32{ @@ -50,9 +53,6 @@ void caller( void *ctxt ); #endif // TEST_H #include "test.h" -#include -#include - // callee( // N : size, // A : [f32][N] @DRAM diff --git a/tests/golden/test_codegen/test_memcpy_instr.txt b/tests/golden/test_codegen/test_memcpy_instr.txt index 2ff0f5c37..e21b4be07 100644 --- a/tests/golden/test_codegen/test_memcpy_instr.txt +++ b/tests/golden/test_codegen/test_memcpy_instr.txt @@ -1,9 +1,6 @@ #include "bar.h" #include -#include -#include - // bar( // n : size, // dst : f32[n] @DRAM, @@ -51,6 +48,9 @@ extern "C" { #endif +#include +#include + // bar( // n : size, diff --git a/tests/golden/test_codegen/test_no_exo_floor_div_after_divide_loop_with_guard.txt b/tests/golden/test_codegen/test_no_exo_floor_div_after_divide_loop_with_guard.txt index 9251b833b..62bc874d9 100644 --- a/tests/golden/test_codegen/test_no_exo_floor_div_after_divide_loop_with_guard.txt +++ b/tests/golden/test_codegen/test_no_exo_floor_div_after_divide_loop_with_guard.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + // foo( // N : size, @@ -46,9 +49,6 @@ void foo( void *ctxt, int_fast32_t N, float* x ); #include "test.h" -#include -#include - // foo( // N : size, // x : f32[N] @DRAM @@ -96,6 +96,9 @@ extern "C" { #endif +#include +#include + // foo( // N : size, @@ -112,9 +115,6 @@ void foo( void *ctxt, int_fast32_t N, float* x ); #include "test.h" -#include -#include - // foo( // N : size, // x : f32[N] @DRAM @@ -169,6 +169,9 @@ extern "C" { #endif +#include +#include + // foo( // N : size, @@ -185,9 +188,6 @@ void foo( void *ctxt, int_fast32_t N, float* x ); #include "test.h" -#include -#include - // foo( // N : size, // x : f32[N] 
@DRAM diff --git a/tests/golden/test_codegen/test_no_exo_floor_div_triangular_access.txt b/tests/golden/test_codegen/test_no_exo_floor_div_triangular_access.txt index 9b437b751..a8a4d1ef1 100644 --- a/tests/golden/test_codegen/test_no_exo_floor_div_triangular_access.txt +++ b/tests/golden/test_codegen/test_no_exo_floor_div_triangular_access.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + // foo( // N : size, @@ -46,9 +49,6 @@ void foo( void *ctxt, int_fast32_t N, float* x ); #include "test.h" -#include -#include - // foo( // N : size, // x : f32[N, N] @DRAM diff --git a/tests/golden/test_codegen/test_pragma_parallel_loop.txt b/tests/golden/test_codegen/test_pragma_parallel_loop.txt index caeca75de..ecbe40aa5 100644 --- a/tests/golden/test_codegen/test_pragma_parallel_loop.txt +++ b/tests/golden/test_codegen/test_pragma_parallel_loop.txt @@ -3,9 +3,6 @@ #include #include -#include -#include - // foo( // x : i8[10] @DRAM // ) diff --git a/tests/golden/test_codegen/test_target_another_exo_library.txt b/tests/golden/test_codegen/test_target_another_exo_library.txt index e31b8d88f..60fbfabd3 100644 --- a/tests/golden/test_codegen/test_target_another_exo_library.txt +++ b/tests/golden/test_codegen/test_target_another_exo_library.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + // foo( // n : size, @@ -46,9 +49,6 @@ void foo( void *ctxt, int_fast32_t n, float* x ); #include "foo.h" -#include -#include - // foo( // n : size, // x : f32[n] @DRAM @@ -97,6 +97,9 @@ extern "C" { #endif +#include +#include + // bar( // n : size, @@ -114,9 +117,6 @@ void bar( void *ctxt, int_fast32_t n, float* y ); #include "bar.h" #include "foo.h" -#include -#include - // bar( // n : size, // y : f32[n] @DRAM diff --git a/tests/golden/test_examples/test_avx2_matmul.txt b/tests/golden/test_examples/test_avx2_matmul.txt index 2e16ce25f..7a4f36443 100644 --- a/tests/golden/test_examples/test_avx2_matmul.txt +++ b/tests/golden/test_examples/test_avx2_matmul.txt @@ 
-30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_1F32 #define EXO_WIN_1F32 struct exo_win_1f32{ @@ -70,9 +73,20 @@ void rank_k_reduce_6x16_scheduled( void *ctxt, int_fast32_t K, const float* A, c #include "test_case.h" #include -#include -#include - +#ifndef EXO_WIN_1F32_AVX2 +#define EXO_WIN_1F32_AVX2 +struct exo_win_1f32_AVX2{ + __m256 * const data; + const int_fast32_t strides[1]; +}; +#endif +#ifndef EXO_WIN_1F32C_AVX2 +#define EXO_WIN_1F32C_AVX2 +struct exo_win_1f32c_AVX2{ + const __m256 * const data; + const int_fast32_t strides[1]; +}; +#endif /* relying on the following instruction..." mm256_broadcast_ss(out,val) diff --git a/tests/golden/test_examples/test_cursors.txt b/tests/golden/test_examples/test_cursors.txt index 521ebb6bc..dbbcbba3c 100644 --- a/tests/golden/test_examples/test_cursors.txt +++ b/tests/golden/test_examples/test_cursors.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + // gemv( // M : size, @@ -49,9 +52,6 @@ void gemv( void *ctxt, int_fast32_t M, int_fast32_t N, const float* A, const flo #include "test_case.h" -#include -#include - // gemv( // M : size, // N : size, diff --git a/tests/golden/test_examples/test_quiz1.txt b/tests/golden/test_examples/test_quiz1.txt index 27e523b1f..4267a3afd 100644 --- a/tests/golden/test_examples/test_quiz1.txt +++ b/tests/golden/test_examples/test_quiz1.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + // vec_double( // N : size, @@ -54,9 +57,6 @@ void vec_double_optimized( void *ctxt, int_fast32_t N, const float* inp, float* #include "test_case.h" -#include -#include - // vec_double( // N : size, // inp : f32[N] @DRAM, diff --git a/tests/golden/test_examples/test_quiz3.txt b/tests/golden/test_examples/test_quiz3.txt index 78f571219..631907aeb 100644 --- a/tests/golden/test_examples/test_quiz3.txt +++ b/tests/golden/test_examples/test_quiz3.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + // tile_and_fused_blur( // W : size, @@ -56,9 
+59,6 @@ void tile_and_fused_blur_scheduled( void *ctxt, int_fast32_t W, int_fast32_t H, #include "test_case.h" -#include -#include - // tile_and_fused_blur( // W : size, // H : size, diff --git a/tests/golden/test_examples/test_rvm_conv1d.txt b/tests/golden/test_examples/test_rvm_conv1d.txt index f67e7baa9..fe7129257 100644 --- a/tests/golden/test_examples/test_rvm_conv1d.txt +++ b/tests/golden/test_examples/test_rvm_conv1d.txt @@ -30,6 +30,9 @@ extern "C" { #endif +#include +#include + #ifndef EXO_WIN_2I32 #define EXO_WIN_2I32 struct exo_win_2i32{ @@ -63,9 +66,6 @@ void exo_conv1d_tile_lt_kw( void *ctxt, const int32_t* data, const int32_t* kern #include #include -#include -#include - // exo_conv1d_tile_lt_kw( // data : i32[4, 16] @DRAM, diff --git a/tests/golden/test_externs/test_expf.txt b/tests/golden/test_externs/test_expf.txt index ebbab2553..bd9d9c7c0 100644 --- a/tests/golden/test_externs/test_expf.txt +++ b/tests/golden/test_externs/test_expf.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - #include // foo( // x : i8[16] @DRAM, @@ -46,6 +43,9 @@ extern "C" { #endif +#include +#include + // foo( // x : i8[16] @DRAM, diff --git a/tests/golden/test_externs/test_fmaxf.txt b/tests/golden/test_externs/test_fmaxf.txt index af16a7798..3197166c4 100644 --- a/tests/golden/test_externs/test_fmaxf.txt +++ b/tests/golden/test_externs/test_fmaxf.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - #include // foo( // x : f32[16] @DRAM, @@ -46,6 +43,9 @@ extern "C" { #endif +#include +#include + // foo( // x : f32[16] @DRAM, diff --git a/tests/golden/test_externs/test_relu.txt b/tests/golden/test_externs/test_relu.txt index f2fd00d91..7cbcd1405 100644 --- a/tests/golden/test_externs/test_relu.txt +++ b/tests/golden/test_externs/test_relu.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - float _relu_float(float x) { if (x > 0.0) return x; else return 0.0; @@ -49,6 +46,9 @@ extern "C" { #endif +#include +#include + // foo( // x : f32[16] @DRAM 
diff --git a/tests/golden/test_externs/test_relu2.txt b/tests/golden/test_externs/test_relu2.txt index 8d5174c56..565f67190 100644 --- a/tests/golden/test_externs/test_relu2.txt +++ b/tests/golden/test_externs/test_relu2.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - float _relu_float(float x) { if (x > 0.0) return x; else return 0.0; @@ -49,6 +46,9 @@ extern "C" { #endif +#include +#include + // foo( // x : f32[16] @DRAM diff --git a/tests/golden/test_externs/test_relu3.txt b/tests/golden/test_externs/test_relu3.txt index d1b294fc3..fa7523563 100644 --- a/tests/golden/test_externs/test_relu3.txt +++ b/tests/golden/test_externs/test_relu3.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - float _relu_float(float x) { if (x > 0.0) return x; else return 0.0; @@ -51,6 +48,9 @@ extern "C" { #endif +#include +#include + // foo( // x : f32[16] @DRAM, diff --git a/tests/golden/test_externs/test_relu4.txt b/tests/golden/test_externs/test_relu4.txt index e1d141c51..4fce22a44 100644 --- a/tests/golden/test_externs/test_relu4.txt +++ b/tests/golden/test_externs/test_relu4.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - int8_t _relu_int8_t(int8_t x) { if (x > 0.0) return x; else return 0.0; @@ -49,6 +46,9 @@ extern "C" { #endif +#include +#include + // foo( // x : i8[16] @DRAM diff --git a/tests/golden/test_externs/test_select.txt b/tests/golden/test_externs/test_select.txt index fa71ccbad..7f44596a7 100644 --- a/tests/golden/test_externs/test_select.txt +++ b/tests/golden/test_externs/test_select.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - int8_t _select_int8_t(int8_t x,int8_t v,int8_t y,int8_t z) { if (x < v) return y; else return z; @@ -51,6 +48,9 @@ extern "C" { #endif +#include +#include + // foo( // x : i8[16] @DRAM, diff --git a/tests/golden/test_externs/test_sigmoid.txt b/tests/golden/test_externs/test_sigmoid.txt index bc202a82b..983f21c95 100644 --- a/tests/golden/test_externs/test_sigmoid.txt +++ 
b/tests/golden/test_externs/test_sigmoid.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - #include float sigmoid(float x) { @@ -51,6 +48,9 @@ extern "C" { #endif +#include +#include + // foo( // x : f32[16] @DRAM, diff --git a/tests/golden/test_externs/test_sin.txt b/tests/golden/test_externs/test_sin.txt index 3c6784c39..519375482 100644 --- a/tests/golden/test_externs/test_sin.txt +++ b/tests/golden/test_externs/test_sin.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - #include // foo( // x : i8[16] @DRAM @@ -45,6 +42,9 @@ extern "C" { #endif +#include +#include + // foo( // x : i8[16] @DRAM diff --git a/tests/golden/test_externs/test_sqrt.txt b/tests/golden/test_externs/test_sqrt.txt index d37ce59b5..7a3b94ab2 100644 --- a/tests/golden/test_externs/test_sqrt.txt +++ b/tests/golden/test_externs/test_sqrt.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - #include // foo( // x : f32[16] @DRAM, @@ -46,6 +43,9 @@ extern "C" { #endif +#include +#include + // foo( // x : f32[16] @DRAM, diff --git a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt index 8e5082820..fa69812cc 100644 --- a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt +++ b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt @@ -6,9 +6,6 @@ def foo(a: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_captured_closure.txt b/tests/golden/test_metaprogramming/test_captured_closure.txt index 569653d3d..a0b36fced 100644 --- a/tests/golden/test_metaprogramming/test_captured_closure.txt +++ b/tests/golden/test_metaprogramming/test_captured_closure.txt @@ -13,9 +13,6 @@ def bar(a: i32 @ DRAM): C: #include "test.h" -#include -#include - // bar( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_conditional.txt b/tests/golden/test_metaprogramming/test_conditional.txt 
index 7e3473e58..c81651627 100644 --- a/tests/golden/test_metaprogramming/test_conditional.txt +++ b/tests/golden/test_metaprogramming/test_conditional.txt @@ -8,9 +8,6 @@ def bar2(a: i8 @ DRAM): C: #include "test.h" -#include -#include - // bar1( // a : i8 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_constant_lifting.txt b/tests/golden/test_metaprogramming/test_constant_lifting.txt index 5ac001ad4..0f0570006 100644 --- a/tests/golden/test_metaprogramming/test_constant_lifting.txt +++ b/tests/golden/test_metaprogramming/test_constant_lifting.txt @@ -4,9 +4,6 @@ def foo(a: f64 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : f64 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt b/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt index 29bc17829..ed9d5d452 100644 --- a/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt +++ b/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt @@ -4,9 +4,6 @@ def foo(a: f32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : f32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_implicit_lhs_unquote.txt b/tests/golden/test_metaprogramming/test_implicit_lhs_unquote.txt index c8720ccf4..6a02f4bac 100644 --- a/tests/golden/test_metaprogramming/test_implicit_lhs_unquote.txt +++ b/tests/golden/test_metaprogramming/test_implicit_lhs_unquote.txt @@ -5,9 +5,6 @@ def foo(a: i32 @ DRAM, b: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM, // b : i32 @DRAM diff --git a/tests/golden/test_metaprogramming/test_local_externs.txt b/tests/golden/test_metaprogramming/test_local_externs.txt index 504175e70..87e832f8f 100644 --- a/tests/golden/test_metaprogramming/test_local_externs.txt +++ b/tests/golden/test_metaprogramming/test_local_externs.txt @@ -4,9 +4,6 @@ def foo(a: f64 @ DRAM): C: #include "test.h" -#include -#include - #include // foo( // a : f64 @DRAM diff --git 
a/tests/golden/test_metaprogramming/test_proc_shadowing.txt b/tests/golden/test_metaprogramming/test_proc_shadowing.txt index 5a3d36701..8c2c50e94 100644 --- a/tests/golden/test_metaprogramming/test_proc_shadowing.txt +++ b/tests/golden/test_metaprogramming/test_proc_shadowing.txt @@ -4,9 +4,6 @@ def foo(a: f32 @ DRAM): C: #include "test.h" -#include -#include - // sin( // a : f32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_quote_complex_expr.txt b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt index b111df4f9..69afd70f2 100644 --- a/tests/golden/test_metaprogramming/test_quote_complex_expr.txt +++ b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt @@ -4,9 +4,6 @@ def foo(a: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_quote_elision.txt b/tests/golden/test_metaprogramming/test_quote_elision.txt index a22821c70..fd2a69dd8 100644 --- a/tests/golden/test_metaprogramming/test_quote_elision.txt +++ b/tests/golden/test_metaprogramming/test_quote_elision.txt @@ -4,9 +4,6 @@ def foo(a: i32 @ DRAM, b: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM, // b : i32 @DRAM diff --git a/tests/golden/test_metaprogramming/test_scope_collision1.txt b/tests/golden/test_metaprogramming/test_scope_collision1.txt index bc9b67584..2ade40bc6 100644 --- a/tests/golden/test_metaprogramming/test_scope_collision1.txt +++ b/tests/golden/test_metaprogramming/test_scope_collision1.txt @@ -6,9 +6,6 @@ def foo(a: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_scope_collision2.txt b/tests/golden/test_metaprogramming/test_scope_collision2.txt index fe7faf52a..285a15c6d 100644 --- a/tests/golden/test_metaprogramming/test_scope_collision2.txt +++ b/tests/golden/test_metaprogramming/test_scope_collision2.txt @@ -4,9 +4,6 @@ def foo(a: i32 @ 
DRAM, b: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM, // b : i32 @DRAM diff --git a/tests/golden/test_metaprogramming/test_scope_nesting.txt b/tests/golden/test_metaprogramming/test_scope_nesting.txt index 0ae39ca18..43eae7c6e 100644 --- a/tests/golden/test_metaprogramming/test_scope_nesting.txt +++ b/tests/golden/test_metaprogramming/test_scope_nesting.txt @@ -4,9 +4,6 @@ def foo(a: i8 @ DRAM, b: i8 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i8 @DRAM, // b : i8 @DRAM diff --git a/tests/golden/test_metaprogramming/test_scoping.txt b/tests/golden/test_metaprogramming/test_scoping.txt index ddd9e9f3f..73672101f 100644 --- a/tests/golden/test_metaprogramming/test_scoping.txt +++ b/tests/golden/test_metaprogramming/test_scoping.txt @@ -4,9 +4,6 @@ def foo(a: i8 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i8 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_statement_assignment.txt b/tests/golden/test_metaprogramming/test_statement_assignment.txt index a8ea5b1a1..4f9e62a44 100644 --- a/tests/golden/test_metaprogramming/test_statement_assignment.txt +++ b/tests/golden/test_metaprogramming/test_statement_assignment.txt @@ -7,9 +7,6 @@ def foo(a: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_type_params.txt b/tests/golden/test_metaprogramming/test_type_params.txt index 98c6282a1..fff9b314b 100644 --- a/tests/golden/test_metaprogramming/test_type_params.txt +++ b/tests/golden/test_metaprogramming/test_type_params.txt @@ -16,9 +16,6 @@ def bar2(a: f64 @ DRAM, b: f64 @ DRAM): C: #include "test.h" -#include -#include - // bar1( // a : i32 @DRAM, // b : i8 @DRAM diff --git a/tests/golden/test_metaprogramming/test_type_quote_elision.txt b/tests/golden/test_metaprogramming/test_type_quote_elision.txt index d9173f3de..4690eb46e 100644 --- 
a/tests/golden/test_metaprogramming/test_type_quote_elision.txt +++ b/tests/golden/test_metaprogramming/test_type_quote_elision.txt @@ -5,9 +5,6 @@ def foo(a: i8 @ DRAM, x: i8[2] @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i8 @DRAM, // x : i8[2] @DRAM diff --git a/tests/golden/test_metaprogramming/test_unary_ops.txt b/tests/golden/test_metaprogramming/test_unary_ops.txt index 028ac6f30..d020f8b64 100644 --- a/tests/golden/test_metaprogramming/test_unary_ops.txt +++ b/tests/golden/test_metaprogramming/test_unary_ops.txt @@ -4,9 +4,6 @@ def foo(a: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_unquote_elision.txt b/tests/golden/test_metaprogramming/test_unquote_elision.txt index 710799136..6cbefcbc5 100644 --- a/tests/golden/test_metaprogramming/test_unquote_elision.txt +++ b/tests/golden/test_metaprogramming/test_unquote_elision.txt @@ -4,9 +4,6 @@ def foo(a: i32 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_unquote_in_slice.txt b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt index de0fc0e9a..16be9ee33 100644 --- a/tests/golden/test_metaprogramming/test_unquote_in_slice.txt +++ b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt @@ -7,9 +7,6 @@ def bar(a: i8[10, 10] @ DRAM): C: #include "test.h" -#include -#include - // bar( // a : i8[10, 10] @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt index 49abf3067..6eb0afc91 100644 --- a/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt +++ b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt @@ -8,9 +8,6 @@ def bar(a: i8[10, 10, 10] @ DRAM): C: #include "test.h" -#include -#include - // bar( // a : i8[10, 10, 10] @DRAM // ) diff --git 
a/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt index ea4f97988..29766fe8e 100644 --- a/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt +++ b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt @@ -11,9 +11,6 @@ def bar(a: i8[10, 10] @ DRAM): C: #include "test.h" -#include -#include - // bar( // a : i8[10, 10] @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_unrolling.txt b/tests/golden/test_metaprogramming/test_unrolling.txt index 136c770c1..8b830a4dc 100644 --- a/tests/golden/test_metaprogramming/test_unrolling.txt +++ b/tests/golden/test_metaprogramming/test_unrolling.txt @@ -15,9 +15,6 @@ def foo(a: i8 @ DRAM): C: #include "test.h" -#include -#include - // foo( // a : i8 @DRAM // ) diff --git a/tests/golden/test_parallel/test_pragma_parallel_loop.txt b/tests/golden/test_parallel/test_pragma_parallel_loop.txt index ce534bf36..d6aaeb8b1 100644 --- a/tests/golden/test_parallel/test_pragma_parallel_loop.txt +++ b/tests/golden/test_parallel/test_pragma_parallel_loop.txt @@ -1,8 +1,5 @@ #include "test.h" -#include -#include - // foo( // x : i8[10] @DRAM // ) diff --git a/tests/golden/test_precision/test_good_prec2.txt b/tests/golden/test_precision/test_good_prec2.txt index 6da2b1674..a63b8cd8e 100644 --- a/tests/golden/test_precision/test_good_prec2.txt +++ b/tests/golden/test_precision/test_good_prec2.txt @@ -21,6 +21,9 @@ #endif +#include +#include + // hoge( // n : size, @@ -30,9 +33,6 @@ void hoge( void *ctxt, int_fast32_t n, const float* x, const float* y ); -#include -#include - // dot( // m : size, // x : f32[m] @DRAM, diff --git a/tests/golden/test_precision/test_good_ui8_prec.txt b/tests/golden/test_precision/test_good_ui8_prec.txt index 50601c276..8317009fd 100644 --- a/tests/golden/test_precision/test_good_ui8_prec.txt +++ b/tests/golden/test_precision/test_good_ui8_prec.txt @@ -21,6 +21,9 @@ #endif +#include +#include + // 
hoge( // n : size, @@ -30,9 +33,6 @@ void hoge( void *ctxt, int_fast32_t n, uint8_t* x, const uint8_t* y ); -#include -#include - // hoge( // n : size, // x : ui8[n] @DRAM, diff --git a/tests/golden/test_schedules/test_expand_dim3.txt b/tests/golden/test_schedules/test_expand_dim3.txt index b0719c3b7..2d0bde341 100644 --- a/tests/golden/test_schedules/test_expand_dim3.txt +++ b/tests/golden/test_schedules/test_expand_dim3.txt @@ -21,6 +21,9 @@ #endif +#include +#include + // foo( // n : size, @@ -30,9 +33,6 @@ void foo( void *ctxt, int_fast32_t n, int_fast32_t m, int8_t* x ); -#include -#include - // foo( // n : size, // m : size, diff --git a/tests/golden/test_window/test_normalize.txt b/tests/golden/test_window/test_normalize.txt index 40bdba83a..e911c9451 100644 --- a/tests/golden/test_window/test_normalize.txt +++ b/tests/golden/test_window/test_normalize.txt @@ -21,6 +21,9 @@ #endif +#include +#include + #ifndef EXO_WIN_1F32C #define EXO_WIN_1F32C struct exo_win_1f32c{ @@ -37,9 +40,6 @@ struct exo_win_1f32c{ void proj( void *ctxt, int_fast32_t n, int_fast32_t m, const float* x, const float* y ); -#include -#include - // dot( // m : size, // x : [f32][m] @DRAM, diff --git a/tests/golden/test_window/test_stride_assert.txt b/tests/golden/test_window/test_stride_assert.txt index 712be9ab0..9afa2bad4 100644 --- a/tests/golden/test_window/test_stride_assert.txt +++ b/tests/golden/test_window/test_stride_assert.txt @@ -21,6 +21,9 @@ #endif +#include +#include + #ifndef EXO_WIN_2I8 #define EXO_WIN_2I8 struct exo_win_2i8{ @@ -44,9 +47,6 @@ struct exo_win_2i8c{ void stride_assert( void *ctxt, int_fast32_t n, int_fast32_t m, struct exo_win_2i8c src, struct exo_win_2i8 dst ); -#include -#include - // stride_assert( // n : size, // m : size, diff --git a/tests/golden/test_window/test_window.txt b/tests/golden/test_window/test_window.txt index 30f66f550..eecb053e1 100644 --- a/tests/golden/test_window/test_window.txt +++ b/tests/golden/test_window/test_window.txt @@ -21,6 
+21,9 @@ #endif +#include +#include + #ifndef EXO_WIN_2I8 #define EXO_WIN_2I8 struct exo_win_2i8{ @@ -44,9 +47,6 @@ struct exo_win_2i8c{ void window( void *ctxt, int_fast32_t n, int_fast32_t m, struct exo_win_2i8c src, struct exo_win_2i8 dst ); -#include -#include - // window( // n : size, // m : size, diff --git a/tests/golden/test_window/test_window_stmt.txt b/tests/golden/test_window/test_window_stmt.txt index d27799f06..cadd58e0b 100644 --- a/tests/golden/test_window/test_window_stmt.txt +++ b/tests/golden/test_window/test_window_stmt.txt @@ -21,6 +21,9 @@ #endif +#include +#include + #ifndef EXO_WIN_1F32C #define EXO_WIN_1F32C struct exo_win_1f32c{ @@ -36,9 +39,6 @@ struct exo_win_1f32c{ void window_stmt( void *ctxt, int_fast32_t n, int_fast32_t m, const float* x ); -#include -#include - // window_stmt( // n : size, // m : size, diff --git a/tests/test_codegen.py b/tests/test_codegen.py index ebdc17118..cc3e790b7 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -685,7 +685,6 @@ def bar(n: size, dst: f32[n], src: f32[n]): optimized_bar = replace(bar, bar.body()[0], memcpy) bar_c, bar_h = compile_procs_to_strings([optimized_bar], "bar.h") - assert f"{bar_c}\n{bar_h}" == golden fn = compiler.compile(optimized_bar) diff --git a/tests/test_special_window.py b/tests/test_special_window.py new file mode 100644 index 000000000..17a946822 --- /dev/null +++ b/tests/test_special_window.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import os +from pathlib import Path + +from exo import ( + proc, + instr, + Procedure, + DRAM, + compile_procs_to_strings, + MemWin, + Memory, + WindowStructCtx, + SpecialWindow, + SpecialWindowFromMemoryCtx, + memwin_template, +) +from exo.libs.memories import MDRAM, MemGenError, StaticMemory, DRAM_STACK +from exo.libs.externs import * +from exo.stdlib.scheduling import * + + +class TestCudaGmem(Memory): + @classmethod + def window_definition(cls, ctx: WindowStructCtx): + return 
ctx.generate_default("TestCudaGmem") + + +@memwin_template +def TestTensorMap(swizzle, *box): + assert len(box) == 2 + smem_outer, smem_inner = box + assert isinstance(smem_outer, int) + assert isinstance(smem_inner, int) + + if swizzle == 0: + cu_swizzle = "CU_TENSOR_MAP_SWIZZLE_NONE" + else: + assert swizzle in (32, 64, 128) + cu_swizzle = f"CU_TENSOR_MAP_SWIZZLE_{swizzle}B" + + class Impl(SpecialWindow): + @classmethod + def global_(cls): + return f"""\ +#include +#include +#include +#include """ + + @classmethod + def window_definition(cls, ctx: WindowStructCtx): + assert ctx.type_shorthand() == "f32" + cu_ctype_enum = "CU_TENSOR_MAP_DATA_TYPE_FLOAT32" + sname = ctx.struct_name("CUtensorMap", cls.memwin_template_parameters) + s_def = f"""\ +struct {sname} {{ + unsigned outer_offset, inner_offset; +}}; + +struct {sname}_strides {{ + unsigned outer, inner; +}}; + +static inline CUtensorMap {sname}_encode_tensor_map( + const void* globalAddress, // window dataptr + struct {sname}_strides gmem_stride, // window layout + unsigned gmem_outer, unsigned gmem_inner) +{{ + assert(gmem_stride.inner == 1); + + CUtensorMap tensorMap; + const CUtensorMapSwizzle swizzle = {cu_swizzle}; + const uint32_t tensorRank = 2; + const cuuint64_t globalDim[2] = {{gmem_inner, gmem_outer}}; + const cuuint64_t globalStrides[1] = {{sizeof({ctx.ctype()}) * gmem_stride.outer}}; + const cuuint32_t boxDim[2] = {{ {smem_inner}, {smem_outer} }}; + const cuuint32_t elementStrides[2] = {{1, 1}}; + const CUtensorMapInterleave interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + const CUtensorMapL2promotion l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + const CUtensorMapFloatOOBfill oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + const CUresult result = cuTensorMapEncodeTiled( + &tensorMap, + {cu_ctype_enum}, + tensorRank, + (void*)globalAddress, + globalDim, + globalStrides, + boxDim, + elementStrides, + interleave, + swizzle, + l2Promotion, + oobFill); + if (result != 0) {{ + fprintf(stderr, 
"cuTensorMapEncodeTiled: %i {{%u, %u}}\\n", (int)result, {smem_inner}, {smem_outer}); + assert(0); + }} + return tensorMap; +}} +""" + return "CUtensorMap", s_def + + @classmethod + def separate_dataptr(cls): + return True + + @classmethod + def window(cls, basetyp, in_expr, indices, strides, srcinfo): + assert len(indices) == 2 + dataptr, in_layout = in_expr + out_layout = f"{{ {in_layout}.outer_offset + {indices[0]}, {in_layout}.inner_offset + {indices[1]} }}" + return dataptr, out_layout + + @classmethod + def can_read(cls): + return False + + @classmethod + def source_memory_type(cls): + return TestCudaGmem + + @classmethod + def from_memory(cls, ctx: SpecialWindowFromMemoryCtx): + sname = ctx.dst_struct_name() + shape0, shape1 = ctx.shape_strs() + clayout = f"(struct {sname}_strides){ctx.src_layout()}" + d_def = f"{sname}_encode_tensor_map(&{ctx.src_data()}, {clayout}, {shape0}, {shape1})" + + # Offsets can be 0'd + w_def = "{}" + return d_def, w_def + + return Impl + + +def test_tensor_map(): + @proc + def test_proc( + tensor: f32[1024, 2048] @ TestCudaGmem, + input_tensor_map: [f32][128, 128] @ TestTensorMap(0, 128, 128), + ): + basic_window = tensor[14, :] + tensor_map_0 = tensor[:, :] @ TestTensorMap(0, 128, 128) + tensor_map_1 = tensor_map_0[14:, :] + tensor_map_C = tensor[14:, :] @ TestTensorMap(128, 196, 128) + tensor_map_D = tensor_map_C[10:, 200:] + + c = test_proc.find("tensor_map_C = _") + assert c.special_window() is TestTensorMap(128, 196, 128) + assert c.special_window() is not TestTensorMap(0, 128, 128) + + cc, hh = compile_procs_to_strings([test_proc], "test.h") + + # This is just a placeholder test for now + if False: + HOME = os.environ["HOME"] + open(f"{HOME}/junk/test.h", "w").write(hh) + open(f"{HOME}/junk/test.c", "w").write(cc) + + # fmt: off + + # TestTensorMap(0, 128, 128) defs should have ended up in the header file. + # There should be no const suffix. 
+ assert "struct exo_win_2f32_CUtensorMap_0_128_128 {" in hh + assert "inline CUtensorMap exo_win_2f32_CUtensorMap_0_128_128_encode_tensor_map" in hh + + # test_proc definition should have separate tensormap, layout + # inputs for input_tensor_map. + assert "CUtensorMap exo_data_input_tensor_map, struct exo_win_2f32_CUtensorMap_0_128_128 input_tensor_map" in hh + + # TestTensorMap(128, 196, 128) defs should have ended up in the C file + assert "struct exo_win_2f32_CUtensorMap_128_196_128 {" in cc + assert "inline CUtensorMap exo_win_2f32_CUtensorMap_128_196_128_encode_tensor_map" in cc + + # Expected window code + assert "struct exo_win_1f32c_TestCudaGmem basic_window = (struct exo_win_1f32c_TestCudaGmem){ &tensor[(14) * (2048)], { 1 } };" in cc + assert "CUtensorMap exo_data_tensor_map_0 = exo_win_2f32_CUtensorMap_0_128_128_encode_tensor_map(&tensor[0], (struct exo_win_2f32_CUtensorMap_0_128_128_strides){ 2048, 1 }, 1024, 2048);" in cc + assert "struct exo_win_2f32_CUtensorMap_0_128_128 tensor_map_0 = {};" in cc + assert "CUtensorMap exo_data_tensor_map_1 = exo_data_tensor_map_0;" in cc + assert "struct exo_win_2f32_CUtensorMap_0_128_128 tensor_map_1 = (struct exo_win_2f32_CUtensorMap_0_128_128) { tensor_map_0.outer_offset + 14, tensor_map_0.inner_offset + 0 };" in cc + assert "CUtensorMap exo_data_tensor_map_C = exo_win_2f32_CUtensorMap_128_196_128_encode_tensor_map(&tensor[(14) * (2048)], (struct exo_win_2f32_CUtensorMap_128_196_128_strides){ 2048, 1 }, 1010, 2048);" in cc + assert "struct exo_win_2f32_CUtensorMap_128_196_128 tensor_map_C = {};" in cc + assert "CUtensorMap exo_data_tensor_map_D = exo_data_tensor_map_C;" in cc + assert "struct exo_win_2f32_CUtensorMap_128_196_128 tensor_map_D = (struct exo_win_2f32_CUtensorMap_128_196_128) { tensor_map_C.outer_offset + 10, tensor_map_C.inner_offset + 200 };" in cc + + # fmt: on