Skip to content

Chapter 7: MLIR based ufunc implementations #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sealir-tutorials/ch04_1_typeinfer_ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def compiler_pipeline(
converter_class=EGraphToRVSDG,
cost_model=None,
backend,
return_module = False,
):

rvsdg_expr, dbginfo = frontend(fn)
Expand Down Expand Up @@ -550,10 +551,12 @@ def define_egraph(
print("cost =", cost)
print(format_rvsdg(extracted))

llmod = backend.lower(extracted, argtypes)
llmod = backend.lower(extracted, argtypes, ignore_passes=return_module)
if verbose:
print("LLVM module".center(80, "="))
print(llmod)
if return_module:
return llmod
return backend.jit_compile(llmod, extracted)


Expand Down Expand Up @@ -1071,7 +1074,7 @@ def lower_cast(self, builder, value, fromty, toty):
f"unsupported lower_cast: {fromty} -> {toty}"
)

def lower(self, root: rg.Func, argtypes):
def lower(self, root: rg.Func, argtypes, ignore_passes=False):
mod = ir.Module()
llargtypes = [*map(self.lower_type, argtypes)]

Expand Down
45 changes: 28 additions & 17 deletions sealir-tutorials/ch06_mlir_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@
import mlir.dialects.cf as cf
import mlir.dialects.func as func
import mlir.dialects.scf as scf
import mlir.runtime as runtime
import mlir.execution_engine as execution_engine
import mlir.runtime as runtime

import mlir.ir as ir
import mlir.passmanager as passmanager
from sealir import ase
from sealir.rvsdg import grammar as rg
from sealir.rvsdg import internal_prefix
import numpy as np

from ch03_egraph_program_rewrites import (
run_test,
Expand Down Expand Up @@ -96,28 +100,23 @@ def __init__(self):
def lower_type(self, ty: NbOp_Type):
match ty:
case NbOp_Type("Int64"):
return ir.IntType(64)
raise NotImplementedError(f"unknown type: {ty}")

def get_mlir_type(self, seal_ty):
match seal_ty.name:
case "Int64":
return self.i64
case "Float64":
case NbOp_Type("Float64"):
return self.f64
raise NotImplementedError(f"unknown type: {ty}")

def lower(self, root: rg.Func, argtypes):
def lower(self, root: rg.Func, argtypes, ignore_passes=False):
context = self.context
self.loc = loc = ir.Location.unknown(context=context)
self.module = module = ir.Module.create(loc=loc)

# Get the module body pointer so we can insert content into the
# module.
module_body = ir.InsertionPoint(module.body)
self.module_body = module_body = ir.InsertionPoint(module.body)
# Convert SealIR types to MLIR types.
input_types = tuple([self.get_mlir_type(x) for x in argtypes])
input_types = tuple([self.lower_type(x) for x in argtypes])
output_types = (
self.get_mlir_type(
self.lower_type(
Attributes(root.body.begin.attrs).get_return_type(root.body)
),
)
Expand Down Expand Up @@ -172,12 +171,21 @@ def get_region_args():
cf.br([], fun.body.blocks[1])

module.dump()
return self.run_passes(module, context)
if ignore_passes:
return module
else:
return self.run_passes(module, context)

def run_passes(self, module, context):
pass_man = passmanager.PassManager(context=context)

pass_man.add("convert-linalg-to-loops")
pass_man.add("convert-scf-to-cf")
pass_man.add("finalize-memref-to-llvm")
pass_man.add("convert-func-to-llvm")
pass_man.add("convert-index-to-llvm")
pass_man.add("reconcile-unrealized-casts")

pass_man.enable_verifier(True)
pass_man.run(module.operation)
# Output LLVM-dialect MLIR
Expand Down Expand Up @@ -384,12 +392,13 @@ def lower_expr(self, expr: SExpr, state: LowerStates):
def jit_compile(self, llmod, func_node: rg.Func):
attributes = Attributes(func_node.body.begin.attrs)
# Convert SealIR types into MLIR types
input_types = tuple(
[self.get_mlir_type(x) for x in attributes.input_types()]
)
with self.loc:
input_types = tuple(
[self.lower_type(x) for x in attributes.input_types()]
)

output_types = (
self.get_mlir_type(
self.lower_type(
Attributes(func_node.body.begin.attrs).get_return_type(
func_node.body
)
Expand All @@ -406,6 +415,8 @@ def get_exec_ptr(mlir_ty, val):
return ctypes.pointer(ctypes.c_float(val))
elif isinstance(mlir_ty, ir.F64Type):
return ctypes.pointer(ctypes.c_double(val))
elif isinstance(mlir_ty, ir.MemRefType):
return ctypes.pointer(ctypes.pointer(runtime.get_ranked_memref_descriptor(val)))


@dataclass(frozen=True)
Expand Down Expand Up @@ -443,7 +454,7 @@ def jit_func(*input_args):
# appended to the end of all input pointers in the invoke call.
engine.invoke(function_name, *input_exec_ptrs, res_ptr)

return res_ptr.contents.value
return res_val

return cls(jit_func)

Expand Down
225 changes: 225 additions & 0 deletions sealir-tutorials/ch07_mlir_ufunc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.7
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# # Ch 7. MLIR ufunc operations
#

# In this chapter, we'll wrap a compiled scalar function in an MLIR
# `linalg.generic` loop so that it can be applied element-wise to NumPy
# arrays, like a ufunc.

from __future__ import annotations

import ctypes

import numpy as np
from ch04_1_typeinfer_ifelse import TypeFloat64
from ch04_2_typeinfer_loops import (
MyCostModel,
base_ruleset,
compiler_pipeline,
setup_argtypes,
)
from ch05_typeinfer_array import NbOp_ArrayType, NbOp_ArrayDimSymbolic, Type
from ch06_mlir_backend import ConditionalExtendGraphtoRVSDG, Backend as _Backend, NbOp_Type

import mlir.dialects.arith as arith
import mlir.dialects.math as math
import mlir.dialects.memref as memref
import mlir.dialects.linalg as linalg
import mlir.dialects.cf as cf
import mlir.dialects.func as func
import mlir.dialects.scf as scf
import mlir.execution_engine as execution_engine
import mlir.ir as ir
import mlir.runtime as runtime
import mlir.passmanager as passmanager

# Type declaration for array elements.
# `Float64` is the backend-level type object; `TypeFloat64` is built with
# `Type.simple` and is what `setup_argtypes` receives in `ufunc_vectorize`.
Float64 = NbOp_Type("Float64")
# NOTE(review): this rebinds the `TypeFloat64` name imported from
# ch04_1_typeinfer_ifelse above -- presumably an equivalent value; confirm.
TypeFloat64 = Type.simple("Float64")

# Define an array type using the Float64 dtype
# and symbolic dimensions (m, n).
array_2d_symbolic = NbOp_ArrayType(
    dtype=Float64,
    ndim=2,
    datalayout="c_contiguous",
    shape=(NbOp_ArrayDimSymbolic("m"), NbOp_ArrayDimSymbolic("n")),
)

class Backend(_Backend):
# Lower symbolic array to respective memref.
# Note: This is not used within ufunc builder,
# since it has explicit declaration of the respective
# MLIR memrefs.
def lower_type(self, ty: NbOp_Type):
match ty:
case NbOp_ArrayType(
dtype=dtype,
ndim=int(ndim),
datalayout=str(datalayout),
shape=shape,
):
mlir_dtype = self.lower_type(dtype)
with self.loc:
memref_ty=ir.MemRefType.get(shape, mlir_dtype)
return memref_ty
return super().lower_type(ty)


# Decorator function for vectorization.
def ufunc_vectorize(input_types, shape=None, ndim=None):
    """Build a decorator that turns a scalar function into an array ufunc.

    The decorated function is compiled to an MLIR module, then a second
    function named ``ufunc`` is added to that module.  ``ufunc`` applies the
    compiled scalar function element-wise over memref arguments using a
    ``linalg.generic`` op and writes into an extra, trailing output memref.

    Parameters
    ----------
    input_types
        One type per argument of the scalar function.  Only ``Float64`` is
        supported.
    shape
        Optional static shape every input array must have.
    ndim
        Optional rank of the input arrays; extents are then dynamic.  When
        both ``ndim`` and ``shape`` are given, the memrefs are dynamic.

    Returns
    -------
    A decorator.  The decorated callable takes ``len(input_types)`` NumPy
    arrays and returns a freshly allocated result array.

    Raises
    ------
    ValueError
        If neither ``shape`` nor ``ndim`` is given, or if they disagree on
        the rank.
    TypeError
        (at decoration time) if an input type is not ``Float64``.
    """
    if shape is None and ndim is None:
        raise ValueError("ufunc_vectorize requires either `shape` or `ndim`")
    # Normalise to a tuple so it compares equal to ``ndarray.shape``.
    static_shape = None if shape is None else tuple(shape)
    if (
        ndim is not None
        and static_shape is not None
        and len(static_shape) != ndim
    ):
        raise ValueError("`shape` and `ndim` disagree on the number of dims")
    rank = ndim if ndim is not None else len(static_shape)
    num_inputs = len(input_types)

    def to_input_dtypes(input_tys):
        # Map backend types to the type-inference types; fail loudly instead
        # of silently dropping an argument type.
        res = []
        for ty in input_tys:
            if ty != Float64:
                raise TypeError(f"unsupported ufunc input type: {ty}")
            res.append(TypeFloat64)
        return tuple(res)

    def wrapper(inner_func):
        # Compile the inner function and get the IR as a module.
        llmod = compiler_pipeline(
            inner_func,
            argtypes=input_types,
            ruleset=(
                base_ruleset | setup_argtypes(*to_input_dtypes(input_types))
            ),
            verbose=True,
            converter_class=ConditionalExtendGraphtoRVSDG,
            cost_model=MyCostModel(),
            backend=Backend(),
            return_module=True,
        )

        # Now within the module declare a separate function named
        # 'ufunc' which acts as a wrapper around the inner 'func'.
        module_body = ir.InsertionPoint(llmod.body)
        context = llmod.context
        loc = ir.Location.unknown(context=context)

        f64 = ir.F64Type.get(context=context)

        with context, loc:
            if ndim is not None:
                memref_ty = ir.MemRefType.get(
                    [ir.ShapedType.get_dynamic_size()] * rank, f64
                )
            else:
                memref_ty = ir.MemRefType.get(static_shape, f64)

        # The function 'ufunc' has N + 1 arguments (where N is the number
        # of arguments of the original function).  The extra argument is an
        # explicitly declared resulting array.
        input_typ_outer = (memref_ty,) * (num_inputs + 1)
        with context, loc, module_body:
            fun = func.FuncOp("ufunc", (input_typ_outer, ()))
            fun.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            const_block = fun.add_entry_block()
            constant_entry = ir.InsertionPoint(const_block)

            # Within this function the input and output arrays are the
            # memref arguments of 'ufunc'.
            with constant_entry:
                arys = fun.arguments[:-1]
                res = fun.arguments[-1]

                # One identity affine map per operand (inputs + output).
                identity_map = ir.AffineMapAttr.get(
                    ir.AffineMap.get(
                        rank,
                        0,
                        [ir.AffineExpr.get_dim(i) for i in range(rank)],
                    )
                )
                indexing_maps = ir.ArrayAttr.get(
                    [identity_map] * (num_inputs + 1)
                )
                # One iterator type per loop dimension (not per operand).
                iterators = ir.ArrayAttr.get(
                    [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
                    * rank
                )
                generic_op = linalg.GenericOp(
                    result_tensors=[],
                    inputs=arys,
                    outputs=[res],
                    indexing_maps=indexing_maps,
                    iterator_types=iterators,
                )
                # The body block takes one scalar per operand: the inputs
                # followed by the current output element.  Only the inputs
                # are forwarded to the inner function.
                body = generic_op.regions[0].blocks.append(
                    *([f64] * (num_inputs + 1))
                )
                with ir.InsertionPoint(body):
                    scalar_args = list(body.arguments)[:num_inputs]
                    m = func.CallOp([f64], "func", scalar_args)
                    linalg.YieldOp([m])
                func.ReturnOp([])

        llmod.dump()
        pass_man = passmanager.PassManager(context=context)

        # Lower linalg -> loops -> control flow -> LLVM dialect.
        pass_man.add("convert-linalg-to-loops")
        pass_man.add("convert-scf-to-cf")
        pass_man.add("finalize-memref-to-llvm")
        pass_man.add("convert-func-to-llvm")
        pass_man.add("convert-index-to-llvm")
        pass_man.add("reconcile-unrealized-casts")
        pass_man.enable_verifier(True)
        pass_man.run(llmod.operation)
        llmod.dump()

        engine = execution_engine.ExecutionEngine(llmod)

        def inner_wrapper(*args):
            assert (
                len(args) == num_inputs
            ), "Number of provided arguments doesn't match definition"
            if static_shape is not None:
                for arg in args:
                    assert (
                        arg.shape == static_shape
                    ), "Provided shape doesn't match ufunc definition"
            else:
                # Per-call local: the shape of one call must not constrain
                # later calls of a dynamically-shaped ufunc.
                expected_shape = args[0].shape
                for arg in args:
                    assert (
                        arg.ndim == rank
                    ), "Provided ndim doesn't match ufunc definition"
                    assert arg.shape == expected_shape, (
                        "Provided arguments have different shapes than "
                        "each other, this is currently not supported"
                    )

            # Declare the resulting NumPy array of same dtype.
            # NOTE(review): the compiled code works on f64 elements, so the
            # inputs are presumably expected to be float64 arrays -- confirm.
            res_array = np.zeros_like(args[0])
            engine_args = [
                ctypes.pointer(
                    ctypes.pointer(runtime.get_ranked_memref_descriptor(arg))
                )
                for arg in (*args, res_array)
            ]
            # Invoke function 'ufunc' using memref descriptors representing
            # the NumPy arrays.
            engine.invoke("ufunc", *engine_args)
            return res_array

        return inner_wrapper

    return wrapper


# Example ufunc: a scalar function of three Float64 values, vectorized over
# 2-D arrays with dynamic extents.  The body is compiled through
# `compiler_pipeline` by the decorator, so it is kept to plain scalar
# arithmetic and documented with comments only.
@ufunc_vectorize(input_types=[Float64, Float64, Float64], ndim=2)
def foo(a, b, c):
    # Offset each input by a constant, then sum the three results.
    x = a + 1.0
    y = b - 2.0
    z = c + 3.0
    return x + y + z

if __name__ == "__main__":
    # Create three identical 10x10 float64 NumPy arrays holding 0..99.
    ary, ary_2, ary_3 = (
        np.arange(100, dtype=np.float64).reshape(10, 10) for _ in range(3)
    )

    # Apply the compiled ufunc element-wise and show the result.
    got = foo(ary, ary_2, ary_3)
    print("Got", got)