diff --git a/sealir-tutorials/ch04_1_typeinfer_ifelse.py b/sealir-tutorials/ch04_1_typeinfer_ifelse.py index bb979cc..1fa55da 100644 --- a/sealir-tutorials/ch04_1_typeinfer_ifelse.py +++ b/sealir-tutorials/ch04_1_typeinfer_ifelse.py @@ -487,75 +487,96 @@ class CompilationError(Exception): pass -def compiler_pipeline( - fn, - argtypes, - *, - verbose=False, - ruleset: Ruleset, - converter_class=EGraphToRVSDG, - cost_model=None, - backend, -): +@dataclass +class Compiler: + converter_class: Backend + backend: int + cost_model: CostModel + verbose: EGraphToRVSDG + + def run_frontend(self, fn): + rvsdg_expr, dbginfo = frontend(fn) + return rvsdg_expr, dbginfo + + def run_middle_end(self, rvsdg_expr, ruleset): + + # Middle end + def define_egraph( + egraph: EGraph, + func: SExpr, + ): + # Define graph root that points to the function + root = GraphRoot(func) + egraph.let("root", root) + + # Define the empty root node for the error messages + errors = ErrorMsg.root() + egraph.let("errors", errors) + + # Run all the rules until saturation + egraph.run(ruleset.saturate()) + + if self.verbose and IN_NOTEBOOK: + # For inspecting the egraph + egraph.display(graphviz=True) + print(egraph.extract(root)) + # Use egglog's default extractor to get the error messages + errmsgs = map( + lambda x: x.eval(), egraph.extract_multiple(errors, n=10) + ) + errmsgs_filtered = [ + get_error_message((meth, args)) + for meth, args in errmsgs + if meth != "root" + ] + if errmsgs_filtered: + # Raise CompilationError if there are compiler errors + raise CompilationError("\n".join(errmsgs_filtered)) - rvsdg_expr, dbginfo = frontend(fn) + try: + cost, extracted = middle_end( + rvsdg_expr, + define_egraph, + converter_class=self.converter_class, + cost_model=self.cost_model, + ) + except ExtractionError as e: + raise CompilationError("extraction failed") from e + + return cost, extracted - print("Before EGraph".center(80, "=")) - print(format_rvsdg(rvsdg_expr)) + def run_backend(self, extracted, argtypes): + return self.backend.lower(extracted, argtypes) - # Middle end - def define_egraph( - egraph: EGraph, - func: SExpr, - ): - # Define graph root that points to the function - root = GraphRoot(func) - egraph.let("root", root) + def lower_py_fn(self, fn, argtypes, ruleset): + + rvsdg_expr, dbginfo = self.run_frontend(fn) - # Define the empty root node for the error messages - errors = ErrorMsg.root() - egraph.let("errors", errors) + print("Before EGraph".center(80, "=")) + print(format_rvsdg(rvsdg_expr)) - # Run all the rules until saturation - egraph.run(ruleset.saturate()) + cost, extracted = self.run_middle_end(rvsdg_expr, ruleset) - if verbose and IN_NOTEBOOK: - # For inspecting the egraph - egraph.display(graphviz=True) - print(egraph.extract(root)) - # Use egglog's default extractor to get the error messages - errmsgs = map( - lambda x: x.eval(), egraph.extract_multiple(errors, n=10) - ) - errmsgs_filtered = [ - get_error_message((meth, args)) - for meth, args in errmsgs - if meth != "root" - ] - if errmsgs_filtered: - # Raise CompilationError if there are compiler errors - raise CompilationError("\n".join(errmsgs_filtered)) + print("Extracted from EGraph".center(80, "=")) + print("cost =", cost) + print(format_rvsdg(extracted)) - try: - cost, extracted = middle_end( - rvsdg_expr, - define_egraph, - converter_class=converter_class, - cost_model=cost_model, - ) - except ExtractionError as e: - raise CompilationError("extraction failed") from e + module = self.run_backend(extracted, argtypes) + + if self.verbose: + print("LLVM module".center(80, "=")) + print(module) + + return module, extracted - print("Extracted from EGraph".center(80, "=")) - print("cost =", cost) - print(format_rvsdg(extracted)) + def run_backend_passes(self, module): + self.backend.run_passes(module) - llmod = backend.lower(extracted, argtypes) - if verbose: - print("LLVM module".center(80, "=")) - print(llmod) - return backend.jit_compile(llmod, extracted) + def compile_module(self, module, egraph_node, func_name="func"): + return self.backend.jit_compile(module, egraph_node, func_name) + def compile_module_(self, llmod, input_types, output_types, function_name="func", exec_engine=None, **execution_engine_params): + return self.backend.jit_compile_(llmod, input_types, output_types, function_name, exec_engine, **execution_engine_params) # + [markdown] jp-MarkdownHeadingCollapsed=true # ### Define EGraph functions for new operations: @@ -1237,7 +1258,7 @@ def lower_expr(self, expr: SExpr, state: LowerStates): raise NotImplementedError(expr) - def jit_compile(self, llmod: ir.Module, func_node: rg.Func): + def jit_compile(self, llmod: ir.Module, func_node: rg.Func, func_name): sym = func_node.fname # Create JIT lljit = llvm.create_lljit_compiler() @@ -1265,6 +1286,9 @@ def get_ctype(self, lltype: ir.Type): case ir.DoubleType(): return ctypes.c_double raise NotImplementedError(lltype) + + def run_passes(self, module, passes): + pass # - @@ -1317,20 +1341,21 @@ def example_1(a, b): return z + a +compiler = Compiler(ExtendEGraphToRVSDG, Backend(), MyCostModel(), verbose=True) + if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_1, argtypes=(Int64, Int64), ruleset=(base_ruleset | setup_argtypes(TypeInt64, TypeInt64)), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) + + jit_func = compiler.compile_module(llvm_module, func_egraph) + args = (10, 33) - run_test(example_1, jt, args, verbose=True) + run_test(example_1, jit_func, args, verbose=True) args = (7, 3) - run_test(example_1, jt, args, verbose=True) + run_test(example_1, jit_func, args, verbose=True) # ## Example 2: add `float()` @@ -1372,7 +1397,7 @@ def ruleset_type_infer_float( if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_2, argtypes=(Int64, Int64), ruleset=( @@ -1380,15 +1405,12 @@ def ruleset_type_infer_float( | setup_argtypes(TypeInt64, TypeInt64) | ruleset_type_infer_float # < --- added for float() ), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) + jit_func = compiler.compile_module(llvm_module, func_egraph) args = (10, 33) - run_test(example_2, jt, args, verbose=True) + run_test(example_2, jit_func, args, verbose=True) args = (7, 3) - run_test(example_2, jt, args, verbose=True) + run_test(example_2, jit_func, args, verbose=True) # ## Example 3: unify mismatching type # @@ -1417,7 +1439,7 @@ def ruleset_failed_to_unify(ty: Type): if __name__ == "__main__": try: - compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_3, argtypes=(Int64, Int64), ruleset=( @@ -1425,11 +1447,7 @@ def ruleset_failed_to_unify(ty: Type): | setup_argtypes(TypeInt64, TypeInt64) | ruleset_type_infer_float | ruleset_failed_to_unify - ), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), + ) ) except CompilationError as e: # Compilation failed because the return type cannot be determined. @@ -1476,7 +1494,7 @@ def ruleset_type_infer_failure_report( if __name__ == "__main__": try: - compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_3, argtypes=(Int64, Int64), ruleset=( @@ -1486,10 +1504,6 @@ def ruleset_type_infer_failure_report( | ruleset_failed_to_unify | ruleset_type_infer_failure_report ), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) except CompilationError as e: diff --git a/sealir-tutorials/ch04_2_typeinfer_loops.py b/sealir-tutorials/ch04_2_typeinfer_loops.py index 7d0142f..8a489eb 100644 --- a/sealir-tutorials/ch04_2_typeinfer_loops.py +++ b/sealir-tutorials/ch04_2_typeinfer_loops.py @@ -66,7 +66,7 @@ ) from ch04_1_typeinfer_ifelse import base_ruleset as _ch4_1_base_ruleset from ch04_1_typeinfer_ifelse import ( - compiler_pipeline, + Compiler, ruleset_failed_to_unify, ruleset_type_infer_failure_report, ruleset_type_infer_float, @@ -289,17 +289,16 @@ def example_1(init, n): return c +compiler = Compiler(ExtendEGraphToRVSDG, Backend(), MyCostModel(), True) + if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_1, argtypes=(Int64, Int64), ruleset=base_ruleset | setup_argtypes(TypeInt64, TypeInt64), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) - run_test(example_1, jt, (10, 7), verbose=True) + jit_func = compiler.compile_module(llvm_module, func_egraph) + run_test(example_1, jit_func, (10, 7), verbose=True) # ## Example 2: Nested Loop example @@ -318,13 +317,10 @@ def example_2(init, n): if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_2, argtypes=(Int64, Int64), ruleset=base_ruleset | setup_argtypes(TypeInt64, TypeInt64), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) - run_test(example_2, jt, (10, 7), verbose=True) + jit_func = compiler.compile_module(llvm_module, func_egraph) + run_test(example_2, jit_func, (10, 7), verbose=True) diff --git a/sealir-tutorials/ch05_typeinfer_array.py b/sealir-tutorials/ch05_typeinfer_array.py index 7278711..85d8e87 100644 --- a/sealir-tutorials/ch05_typeinfer_array.py +++ b/sealir-tutorials/ch05_typeinfer_array.py @@ -74,7 +74,7 @@ TypeInt64, TypeVar, base_ruleset, - compiler_pipeline, + Compiler, setup_argtypes, ) from utils import IN_NOTEBOOK @@ -402,9 +402,11 @@ class CtypeInt64Array1D(ctypes.Structure): "array_int64_1d", shape=("n",), dtype=TypeInt64, layout="c" ) +compiler = Compiler(ExtendEGraphToRVSDG, Backend(), MyCostModel(), True) + if __name__ == "__main__": # compile - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_1, argtypes=(array_1d_symbolic, Int64), ruleset=( @@ -413,11 +415,8 @@ class CtypeInt64Array1D(ctypes.Structure): | ruleset(*array_infos) | ruleset_typeinfer_array_getitem ), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) + jit_func = compiler.compile_module(llvm_module, func_egraph) # create array ary = np.arange(10, dtype=np.int64) # prepare array for passing to C-API @@ -425,7 +424,7 @@ class CtypeInt64Array1D(ctypes.Structure): param_ary.ptr = ary.ctypes.data param_ary.shape[0] = ary.shape[0] # call the compiled function - got = jt(ctypes.byref(param_ary), 3) + got = jit_func(ctypes.byref(param_ary), 3) print("got", got) # compare the result expect = example_1(ary, 3) @@ -447,7 +446,7 @@ def example_2(ary, size): if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_2, argtypes=(array_1d_symbolic, Int64), ruleset=( @@ -456,17 +455,15 @@ def example_2(ary, size): | ruleset(*array_infos) | ruleset_typeinfer_array_getitem ), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) + jit_func = compiler.compile_module(llvm_module, func_egraph) + ary = np.arange(10, dtype=np.int64) param_ary = CtypeInt64Array1D() param_ary.ptr = ary.ctypes.data param_ary.shape[0] = ary.shape[0] - got = jt(ctypes.byref(param_ary), ary.size) + got = jit_func(ctypes.byref(param_ary), ary.size) print("got", got) expect = example_2(ary, ary.size) assert got == expect diff --git a/sealir-tutorials/ch06_mlir_backend.py b/sealir-tutorials/ch06_mlir_backend.py index 31937d7..5947b8a 100644 --- a/sealir-tutorials/ch06_mlir_backend.py +++ b/sealir-tutorials/ch06_mlir_backend.py @@ -27,12 +27,17 @@ import mlir.dialects.cf as cf import mlir.dialects.func as func import mlir.dialects.scf as scf +import mlir.runtime as runtime import mlir.execution_engine as execution_engine +import mlir.runtime as runtime + import mlir.ir as ir import mlir.passmanager as passmanager +import numba.cuda from sealir import ase from sealir.rvsdg import grammar as rg from sealir.rvsdg import internal_prefix +import numpy as np from ch03_egraph_program_rewrites import ( run_test, @@ -60,7 +65,7 @@ ) from ch04_1_typeinfer_ifelse import base_ruleset as if_else_ruleset from ch04_1_typeinfer_ifelse import ( - compiler_pipeline, + Compiler, ruleset_type_infer_float, setup_argtypes, ) @@ -73,6 +78,7 @@ from ch04_2_typeinfer_loops import base_ruleset as loop_ruleset from utils import IN_NOTEBOOK +_DEBUG = False @dataclass(frozen=True) class LowerStates(ase.TraverseState): @@ -88,6 +94,7 @@ class LowerStates(ase.TraverseState): class Backend: def __init__(self): self.context = context = ir.Context() + self.f32 = ir.F32Type.get(context=context) self.f64 = ir.F64Type.get(context=context) self.i32 = ir.IntegerType.get_signless(32, context=context) self.i64 = ir.IntegerType.get_signless(64, context=context) @@ -96,15 +103,12 @@ def __init__(self): def lower_type(self, ty: NbOp_Type): match ty: case NbOp_Type("Int64"): - return ir.IntType(64) - raise NotImplementedError(f"unknown type: {ty}") - - def get_mlir_type(self, seal_ty): - match seal_ty.name: - case "Int64": return self.i64 - case "Float64": + case NbOp_Type("Float64"): return self.f64 + case NbOp_Type("Float32"): + return self.f32 + raise NotImplementedError(f"unknown type: {ty}") def lower(self, root: rg.Func, argtypes): context = self.context @@ -113,11 +117,11 @@ def lower(self, root: rg.Func, argtypes): # Get the module body pointer so we can insert content into the # module. - module_body = ir.InsertionPoint(module.body) + self.module_body = module_body = ir.InsertionPoint(module.body) # Convert SealIR types to MLIR types. - input_types = tuple([self.get_mlir_type(x) for x in argtypes]) + input_types = tuple([self.lower_type(x) for x in argtypes]) output_types = ( - self.get_mlir_type( + self.lower_type( Attributes(root.body.begin.attrs).get_return_type(root.body) ), ) @@ -171,13 +175,24 @@ def get_region_args(): with context, loc, constant_entry: cf.br([], fun.body.blocks[1]) + return module + + def run_passes(self, module): module.dump() - return self.run_passes(module, context) - def run_passes(self, module, context): - pass_man = passmanager.PassManager(context=context) + if _DEBUG: + module.context.enable_multithreading(False) + if _DEBUG: + pass_man.enable_ir_printing() + + pass_man = passmanager.PassManager(context=module.context) + pass_man.add("convert-linalg-to-loops") pass_man.add("convert-scf-to-cf") + pass_man.add("finalize-memref-to-llvm") + pass_man.add("convert-math-to-libm") pass_man.add("convert-func-to-llvm") + pass_man.add("convert-index-to-llvm") + pass_man.add("reconcile-unrealized-casts") pass_man.enable_verifier(True) pass_man.run(module.operation) # Output LLVM-dialect MLIR @@ -381,46 +396,36 @@ def lower_expr(self, expr: SExpr, state: LowerStates): case _: raise NotImplementedError(expr, type(expr)) - def jit_compile(self, llmod, func_node: rg.Func): + def jit_compile(self, llmod, func_node: rg.Func, func_name="func"): attributes = Attributes(func_node.body.begin.attrs) # Convert SealIR types into MLIR types - input_types = tuple( - [self.get_mlir_type(x) for x in attributes.input_types()] - ) + with self.loc: + input_types = tuple( + [self.lower_type(x) for x in attributes.input_types()] + ) output_types = ( - self.get_mlir_type( + self.lower_type( Attributes(func_node.body.begin.attrs).get_return_type( func_node.body ) ), ) - # Converts the MLIR module into a JIT-callable function. - return JitCallable.from_pointer(llmod, input_types, output_types) - - -def get_exec_ptr(mlir_ty, val): - if isinstance(mlir_ty, ir.IntegerType): - return ctypes.pointer(ctypes.c_int64(val)) - elif isinstance(mlir_ty, ir.F32Type): - return ctypes.pointer(ctypes.c_float(val)) - elif isinstance(mlir_ty, ir.F64Type): - return ctypes.pointer(ctypes.c_double(val)) - - -@dataclass(frozen=True) -class JitCallable: - jit_func: Callable + return self.jit_compile_(llmod, input_types, output_types) @classmethod - def from_pointer(cls, jit_module, input_types, output_types): + def jit_compile_(cls, llmod, input_types, output_types, function_name="func", exec_engine=None, **execution_engine_params): + # Converts the MLIR module into a JIT-callable function. # Use MLIR's own internal execution engine - engine = execution_engine.ExecutionEngine(jit_module) - + if exec_engine is None: + engine = execution_engine.ExecutionEngine(llmod, **execution_engine_params) + else: + engine = exec_engine + assert ( len(output_types) == 1 ), "Execution of functions with output arguments > 1 not supported" - res_ptr = get_exec_ptr(output_types[0], 0) + res_ptr, res_val = cls.get_exec_ptr(output_types[0], None) # Build a wrapper function def jit_func(*input_args): @@ -434,7 +439,7 @@ def jit_func(*input_args): # the internal execution engine should # be C-Type pointers. input_exec_ptrs = [ - get_exec_ptr(ty, val) + cls.get_exec_ptr(ty, val)[0] for ty, val in zip(input_types, input_args) ] # Invokes the function that we built, internally calls @@ -443,13 +448,39 @@ def jit_func(*input_args): # appended to the end of all input pointers in the invoke call. engine.invoke(function_name, *input_exec_ptrs, res_ptr) - return res_ptr.contents.value - - return cls(jit_func) + return cls.get_out_val(res_ptr, res_val) - def __call__(self, *args: Any) -> Any: - return self.jit_func(*args) + return jit_func + @classmethod + def get_exec_ptr(cls, mlir_ty, val): + if isinstance(mlir_ty, ir.IntegerType): + val = 0 if val is None else val + ptr = ctypes.pointer(ctypes.c_int64(val)) + elif isinstance(mlir_ty, ir.F32Type): + val = 0.0 if val is None else val + ptr = ctypes.pointer(ctypes.c_float(val)) + elif isinstance(mlir_ty, ir.F64Type): + val = 0.0 if val is None else val + ptr = ctypes.pointer(ctypes.c_double(val)) + elif isinstance(mlir_ty, ir.MemRefType): + if isinstance(mlir_ty.element_type, ir.F64Type): + np_dtype = np.float64 + elif isinstance(mlir_ty.element_type, ir.F32Type): + np_dtype = np.float32 + else: + raise TypeError("The current array element type is not supported") + val = np.zeros(mlir_ty.shape, dtype=np_dtype) if val is None else val + ptr = ctypes.pointer(ctypes.pointer(runtime.get_ranked_memref_descriptor(val))) + + return ptr, val + + @classmethod + def get_out_val(cls, res_ptr, res_val): + if isinstance(res_val, np.ndarray): + return res_val + else: + return res_ptr.contents.value # + [markdown] jp-MarkdownHeadingCollapsed=true # Example 1: simple if-else @@ -463,21 +494,20 @@ def example_1(a, b): z = b - a return z + a +compiler = Compiler(ConditionalExtendGraphtoRVSDG, Backend(), MyCostModel(), True) if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_1, argtypes=(Int64, Int64), ruleset=(if_else_ruleset | setup_argtypes(TypeInt64, TypeInt64)), - verbose=True, - converter_class=ConditionalExtendGraphtoRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) + compiler.run_backend_passes(llvm_module) + jit_func = compiler.compile_module(llvm_module, func_egraph) args = (10, 33) - run_test(example_1, jt, args, verbose=True) + run_test(example_1, jit_func, args, verbose=True) args = (7, 3) - run_test(example_1, jt, args, verbose=True) + run_test(example_1, jit_func, args, verbose=True) # ## Example 2: add `float()` @@ -495,7 +525,7 @@ def example_2(a, b): if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_2, argtypes=(Int64, Int64), ruleset=( @@ -503,15 +533,13 @@ def example_2(a, b): | setup_argtypes(TypeInt64, TypeInt64) | ruleset_type_infer_float # < --- added for float() ), - verbose=True, - converter_class=ConditionalExtendGraphtoRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) + compiler.run_backend_passes(llvm_module) + jit_func = compiler.compile_module(llvm_module, func_egraph) args = (10, 33) - run_test(example_2, jt, args, verbose=True) + run_test(example_2, jit_func, args, verbose=True) args = (7, 3) - run_test(example_2, jt, args, verbose=True) + run_test(example_2, jit_func, args, verbose=True) # ## Example 3: Simple while loop example @@ -524,18 +552,18 @@ def example_3(init, n): i = i + 1 return c +compiler = Compiler(LoopExtendEGraphToRVSDG, Backend(), MyCostModel(), True) if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_3, argtypes=(Int64, Int64), ruleset=loop_ruleset | setup_argtypes(TypeInt64, TypeInt64), - verbose=True, - converter_class=LoopExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) - run_test(example_3, jt, (10, 7), verbose=True) + compiler.run_backend_passes(llvm_module) + jit_func = compiler.compile_module(llvm_module, func_egraph) + + run_test(example_3, jit_func, (10, 7), verbose=True) # ## Example 4: Nested Loop example @@ -554,13 +582,12 @@ def example_4(init, n): if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( example_4, argtypes=(Int64, Int64), ruleset=loop_ruleset | setup_argtypes(TypeInt64, TypeInt64), - verbose=True, - converter_class=LoopExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) - run_test(example_4, jt, (10, 7), verbose=True) + compiler.run_backend_passes(llvm_module) + jit_func = compiler.compile_module(llvm_module, func_egraph) + + run_test(example_4, jit_func, (10, 7), verbose=True) diff --git a/sealir-tutorials/ch07_mlir_ufunc.py b/sealir-tutorials/ch07_mlir_ufunc.py new file mode 100644 index 0000000..11685d0 --- /dev/null +++ b/sealir-tutorials/ch07_mlir_ufunc.py @@ -0,0 +1,166 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.7 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# # Ch 7. MLIR ufunc operations +# + +# In this chapter, we'll look at type inference for array operations. + +from __future__ import annotations +import inspect + +import numpy as np +from ch04_2_typeinfer_loops import ( + MyCostModel, + Compiler, + setup_argtypes, +) +from ch05_typeinfer_array import NbOp_ArrayType, NbOp_ArrayDimSymbolic, Type, base_ruleset +from ch06_mlir_backend import ConditionalExtendGraphtoRVSDG, Backend as _Backend, NbOp_Type + +import mlir.dialects.linalg as linalg +import mlir.dialects.func as func +import mlir.ir as ir + +# Type declaration for array elements +Float64 = NbOp_Type("Float64") +TypeFloat64 = Type.simple("Float64") +Float32 = NbOp_Type("Float32") +TypeFloat32 = Type.simple("Float32") + +class Backend(_Backend): + # Lower symbolic array to respective memref. + # Note: This is not used within ufunc builder, + # since it has explicit declaration of the respective + # MLIR memrefs. + def lower_type(self, ty: NbOp_Type): + match ty: + case NbOp_ArrayType( + dtype=dtype, + ndim=int(ndim), + datalayout=str(datalayout), + shape=shape, + ): + mlir_dtype = self.lower_type(dtype) + with self.loc: + memref_ty=ir.MemRefType.get(shape, mlir_dtype) + return memref_ty + return super().lower_type(ty) + +# Decorator function for vecotrization. +def ufunc_vectorize(input_type, shape, ufunc_compiler, extra_ruleset=None): + def to_input_dtypes(ty): + if ty == Float64: + return TypeFloat64 + elif ty == Float32: + return TypeFloat32 + + def wrapper(inner_func): + sig = inspect.signature(inner_func) + num_inputs = len(sig.parameters) + ruleset = ( + base_ruleset + | setup_argtypes(*(to_input_dtypes(input_type),)*num_inputs) + ) + if extra_ruleset is not None: + ruleset |= extra_ruleset + # Compile the inner function and get the IR as a module. + llmod, func_egraph = ufunc_compiler.lower_py_fn( + inner_func, + argtypes=(input_type,)*num_inputs, + ruleset=ruleset, + ) + + # Now within the module declare a seperate function named + # 'ufunc' which acts as a wrapper around the innner 'func' + with llmod.context, ir.Location.unknown(context=llmod.context), ir.InsertionPoint(llmod.body): + f32 = ir.F32Type.get() + f64 = ir.F64Type.get() + + match input_type.name: + case 'Float32': + internal_dtype = f32 + case 'Float64': + internal_dtype = f64 + case _ : + raise TypeError("The current input type is not supported") + + ndim = len(shape) + memref_ty = ir.MemRefType.get(shape, internal_dtype) + + # The function 'ufunc' has N + 1 number of arguments + # (where N is the nuber of arguments for the original function) + # The extra argument is an explicitly declared resulting array. + input_typ_outer = (memref_ty,) * (num_inputs + 1) + + fun = func.FuncOp("ufunc", (input_typ_outer, ())) + fun.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + const_block = fun.add_entry_block() + constant_entry = ir.InsertionPoint(const_block) + + # Within this function we declare the symbolic representation of + # input and output arrays of appropriate shapes using memrefs. + with constant_entry: + arys = fun.arguments[:-1] + res = fun.arguments[-1] + + # Affine map declaration + indexing_maps = ir.ArrayAttr.get([ + ir.AffineMapAttr.get(ir.AffineMap.get(ndim, 0, [ + ir.AffineExpr.get_dim(i) for i in range(ndim) + ])), + ] * (num_inputs + 1)) + iterators = ir.ArrayAttr.get([ + ir.Attribute.parse(f"#linalg.iterator_type") + ] * (num_inputs)) + matmul = linalg.GenericOp( + result_tensors=[], + inputs=arys, + outputs=[res], + indexing_maps=indexing_maps, + iterator_types=iterators + ) + # Within the affine loop body make calls to the inner function. + body = matmul.regions[0].blocks.append(*([internal_dtype] * (num_inputs + 1))) + with ir.InsertionPoint(body): + m = func.CallOp([internal_dtype], "func", [*body.arguments[:-1]]) + linalg.YieldOp([m]) + func.ReturnOp([]) + + ufunc_compiler.run_backend_passes(llmod) + + jit_func = ufunc_compiler.compile_module_(llmod, [memref_ty] * num_inputs, (memref_ty,), "ufunc") + return jit_func + + return wrapper + + +if __name__ == "__main__": + compiler = Compiler(ConditionalExtendGraphtoRVSDG, Backend(), MyCostModel(), True) + + @ufunc_vectorize(input_type=Float64, shape=(10, 10), ufunc_compiler=compiler) + def foo(a, b, c): + x = a + 1.0 + y = b - 2.0 + z = c + 3.0 + return x + y + z + + # Create NumPy arrays + ary = np.arange(100, dtype=np.float64).reshape(10, 10) + ary_2 = np.arange(100, dtype=np.float64).reshape(10, 10) + ary_3 = np.arange(100, dtype=np.float64).reshape(10, 10) + + got = foo(ary, ary_2, ary_3) + print("Got", got) + diff --git a/sealir-tutorials/ch08_gpu_offload.py b/sealir-tutorials/ch08_gpu_offload.py new file mode 100644 index 0000000..bbd9d1e --- /dev/null +++ b/sealir-tutorials/ch08_gpu_offload.py @@ -0,0 +1,159 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.7 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# # Ch 8. GPU offloading for MLIR ufunc operations +# + +# In this chapter, we'll look at type inference for array operations. + +from __future__ import annotations + +import numpy as np +from ch04_2_typeinfer_loops import ( + MyCostModel, + Compiler, +) +from ch05_typeinfer_array import NbOp_ArrayType +from ch06_mlir_backend import ConditionalExtendGraphtoRVSDG, Backend as _Backend, NbOp_Type +from ch07_mlir_ufunc import ufunc_vectorize, Float64 + +import mlir.ir as ir +import mlir.execution_engine as execution_engine +import mlir.runtime as runtime +import mlir.passmanager as passmanager +import ctypes +from ctypes.util import find_library +from numba import cuda +from collections import namedtuple + +_DEBUG = True + +class GPUBackend(_Backend): + # Lower symbolic array to respective memref. + # Note: This is not used within ufunc builder, + # since it has explicit declaration of the respective + # MLIR memrefs. + def lower_type(self, ty: NbOp_Type): + match ty: + case NbOp_ArrayType( + dtype=dtype, + ndim=int(ndim), + datalayout=str(datalayout), + shape=shape, + ): + mlir_dtype = self.lower_type(dtype) + with self.loc: + memref_ty=ir.MemRefType.get(shape, mlir_dtype) + return memref_ty + return super().lower_type(ty) + + def run_passes(self, module): + module.dump() + pass_man = passmanager.PassManager(context=module.context) + + if _DEBUG: + module.context.enable_multithreading(False) + if _DEBUG: + pass_man.enable_ir_printing() + + pass_man.add("convert-linalg-to-affine-loops") + pass_man.add("affine-loop-fusion") + pass_man.add("inline") + pass_man.add("func.func(affine-parallelize)") + pass_man.add("builtin.module(func.func(gpu-map-parallel-loops,convert-parallel-loops-to-gpu))") + pass_man.add("lower-affine") + pass_man.add("scf-parallel-loop-fusion") + pass_man.add('func.func(gpu-map-parallel-loops,convert-parallel-loops-to-gpu)') + pass_man.add("gpu-kernel-outlining") + pass_man.add('gpu-lower-to-nvvm-pipeline{cubin-format="fatbin"}') + pass_man.add("convert-scf-to-cf") + pass_man.add("finalize-memref-to-llvm") + pass_man.add("convert-math-to-libm") + pass_man.add("convert-func-to-llvm") + pass_man.add("convert-index-to-llvm") + pass_man.add("convert-bufferization-to-memref") + pass_man.add("reconcile-unrealized-casts") + pass_man.add("func.func(llvm-request-c-wrappers)") + pass_man.enable_verifier(True) + pass_man.run(module.operation) + # Output LLVM-dialect MLIR + module.dump() + return module + + @classmethod + def get_exec_ptr(cls, mlir_ty, val): + if isinstance(mlir_ty, ir.IntegerType): + val = 0 if val is None else val + ptr = ctypes.pointer(ctypes.c_int64(val)) + elif isinstance(mlir_ty, ir.F32Type): + val = 0.0 if val is None else val + ptr = ctypes.pointer(ctypes.c_float(val)) + elif isinstance(mlir_ty, ir.F64Type): + val = 0.0 if val is None else val + ptr = ctypes.pointer(ctypes.c_double(val)) + elif isinstance(mlir_ty, ir.MemRefType): + if isinstance(mlir_ty.element_type, ir.F64Type): + np_dtype = np.float64 + elif isinstance(mlir_ty.element_type, ir.F32Type): + np_dtype = np.float32 + else: + raise TypeError("The current array element type is not supported") + val = np.zeros(mlir_ty.shape, dtype=np_dtype) if val is None else val + val = cls.np_arr_to_np_duck_device_arr(val) + ptr = ctypes.pointer(ctypes.pointer(runtime.get_ranked_memref_descriptor(val))) + + return ptr, val + + @classmethod + def get_out_val(cls, res_ptr, res_val): + if isinstance(res_val, cuda.cudadrv.devicearray.DeviceNDArray): + return res_val.copy_to_host() + else: + return super().get_out_val(res_ptr, res_val) + + @classmethod + def np_arr_to_np_duck_device_arr(cls, arr): + da = cuda.to_device(arr) + ctlie = namedtuple("ctypes_lie", "data data_as shape") + da.ctypes = ctlie(da.__cuda_array_interface__["data"][0], + lambda x: ctypes.cast(da.ctypes.data, x), + da.__cuda_array_interface__["shape"],) + da.itemsize = arr.itemsize + return da + + @classmethod + def jit_compile_(cls, llmod, input_types, output_types, function_name="func", exec_engine=None, **execution_engine_params): + cuda_libs = ("mlir_cuda_runtime", "mlir_c_runner_utils", "mlir_runner_utils") + cuda_shared_libs = [find_library(x) for x in cuda_libs] + return super().jit_compile_(llmod, input_types, output_types, function_name, exec_engine=execution_engine.ExecutionEngine(llmod, opt_level=3, shared_libs=cuda_shared_libs)) + + +if __name__ == "__main__": + gpu_compiler = Compiler(ConditionalExtendGraphtoRVSDG, GPUBackend(), MyCostModel(), True) + + @ufunc_vectorize(input_type=Float64, shape=(10, 10), ufunc_compiler=gpu_compiler) + def foo(a, b, c): + x = a + 1.0 + y = b - 2.0 + z = c + 3.0 + return x + y + z + + # Create NumPy arrays + ary = np.arange(100, dtype=np.float64).reshape(10, 10) + ary_2 = np.arange(100, dtype=np.float64).reshape(10, 10) + ary_3 = np.arange(100, dtype=np.float64).reshape(10, 10) + + got = foo(ary, ary_2, ary_3) + print("Got", got) + diff --git a/sealir-tutorials/demo01_gelu_tanh_approx.py b/sealir-tutorials/demo01_gelu_tanh_approx.py index 09f45b5..1ab8d16 100644 --- a/sealir-tutorials/demo01_gelu_tanh_approx.py +++ b/sealir-tutorials/demo01_gelu_tanh_approx.py @@ -65,16 +65,17 @@ ) from ch05_typeinfer_array import MyCostModel as ch06_CostModel from ch05_typeinfer_array import ( - NbOp_Type, base_ruleset, - compiler_pipeline, + Compiler, ) -from ch06_mlir_backend import Backend as ch06_Backend from ch06_mlir_backend import LowerStates, run_test +from ch07_mlir_ufunc import ufunc_vectorize, Float32, Backend as UfuncBackend, TypeFloat32 +from ch08_gpu_offload import GPUBackend +import mlir.dialects.arith as arith +import mlir.dialects.math as math # ## The GELU function - def gelu_tanh_forward(a): dt = np.float32 result = ( @@ -264,11 +265,6 @@ def ruleset_typeinfer_f32_ops(res: Term, x: Term, y: Term): # ### Extend the RVSDG Grammar -# + -TypeFloat32 = Type.simple("Float32") - -Float32 = NbOp_Type("Float32") - SExpr = rvsdg.grammar.SExpr @@ -292,7 +288,6 @@ class NbOp_Mul_Float32(NbOp_Base): lhs: SExpr rhs: SExpr - class NbOp_Div_Float32(NbOp_Base): lhs: SExpr rhs: SExpr @@ -357,7 +352,7 @@ def handle_Module( # ### Extend the backend -class Backend(ch06_Backend): +class Backend(UfuncBackend): def __init__(self): super().__init__() self.f32 = ir.F32Type.get(context=self.context) @@ -369,9 +364,6 @@ def get_mlir_type(self, seal_ty): return super().get_mlir_type(seal_ty) def lower_expr(self, expr: SExpr, state: LowerStates): - import mlir.dialects.arith as arith - import mlir.dialects.math as math - match expr: case NbOp_Add_Float32(lhs, rhs): lhs = yield lhs @@ -405,18 +397,6 @@ def lower_expr(self, expr: SExpr, state: LowerStates): return arith.constant(self.i32, 0) return (yield from super().lower_expr(expr, state)) - def run_passes(self, module, context): - import mlir.passmanager as passmanager - - pass_man = passmanager.PassManager(context=context) - pass_man.add("convert-scf-to-cf") - pass_man.add("convert-math-to-libm") - pass_man.add("convert-func-to-llvm") - pass_man.enable_verifier(True) - pass_man.run(module.operation) - module.dump() - return module - # ## Cost Model # @@ -440,21 +420,19 @@ def get_cost_function(self, nodename, op, ty, cost, children): # ### Run the baseline function +compiler = Compiler(ExtendEGraphToRVSDG, Backend(), MyCostModel(), True) if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( gelu_tanh_forward, argtypes=(Float32,), ruleset=( base_ruleset | setup_argtypes(TypeFloat32) | additional_rules ), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) - run_test(gelu_tanh_forward, jt, (0.234,), verbose=True) - + compiler.run_backend_passes(llvm_module) + jit_func = compiler.compile_module(llvm_module, func_egraph) + run_test(gelu_tanh_forward, jit_func, (0.234,), verbose=True) # ## Add rules to optimize @@ -502,7 +480,7 @@ def pow_expansion(x: Term, ival: i64): # ### Run the optimized function if __name__ == "__main__": - jt = compiler_pipeline( + llvm_module, func_egraph = compiler.lower_py_fn( gelu_tanh_forward, argtypes=(Float32,), ruleset=( @@ -511,11 +489,9 @@ def pow_expansion(x: Term, ival: i64): | additional_rules | optimize_rules ), - verbose=True, - converter_class=ExtendEGraphToRVSDG, - cost_model=MyCostModel(), - backend=Backend(), ) + compiler.run_backend_passes(llvm_module) + jit_func = compiler.compile_module(llvm_module, func_egraph) # ### Compare the result # @@ -524,4 +500,19 @@ def pow_expansion(x: Term, ival: i64): if __name__ == "__main__": relclose = lambda x, y: np.allclose(x, y, rtol=1e-6) - run_test(gelu_tanh_forward, jt, (0.234,), equal=relclose) + run_test(gelu_tanh_forward, jit_func, (0.234,), equal=relclose) + + vectorized_gelu = ufunc_vectorize(input_type=Float32, shape=(10,), ufunc_compiler=compiler, extra_ruleset=additional_rules)(gelu_tanh_forward) + relclose = lambda x, y: np.allclose(x, y, rtol=1e-6) + input_val = np.array([0.234]*10, dtype=np.float32) + run_test(gelu_tanh_forward, vectorized_gelu, (input_val,), equal=relclose) + + class Backend2(Backend, GPUBackend): + pass + + gpu_compiler = Compiler(ExtendEGraphToRVSDG, Backend2(), MyCostModel(), True) + + vectorized_gelu = ufunc_vectorize(input_type=Float32, shape=(10,), ufunc_compiler=gpu_compiler, extra_ruleset=additional_rules)(gelu_tanh_forward) + relclose = lambda x, y: np.allclose(x, y, rtol=1e-6) + input_val = np.array([0.234]*10, dtype=np.float32) + run_test(gelu_tanh_forward, vectorized_gelu, (input_val,), equal=relclose)