diff --git a/.gitignore b/.gitignore index aba1851..238af48 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ build*/ Brutus/Manifest.toml llvm julia +MLIR.jl +Brutus/dev diff --git a/Brutus/Project.toml b/Brutus/Project.toml index 9e5698d..9520d08 100644 --- a/Brutus/Project.toml +++ b/Brutus/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" +MLIR = "bfde9dd4-8f40-4a1e-be09-1475335e1c92" [compat] julia = "1.5" diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl new file mode 100644 index 0000000..e864376 --- /dev/null +++ b/Brutus/scratch/juliacodegen.jl @@ -0,0 +1,65 @@ +module JuliaCodegen + +using Brutus +using MLIR + +println("\n---- brutus_id ----\n") +function brutus_id(N) + return N +end + +v = Brutus.call(brutus_id, 5) +@time v = Brutus.call(brutus_id, 5) +display(v) + +println("\n---- brutus_add ----\n") + +function brutus_add(N1, N2) + return N1 + N2 +end + +v = Brutus.call(brutus_add, 5.0, 10.0) +@time v = Brutus.call(brutus_add, 5.0, 10.0) +display(v) + +println("\n---- structs ----\n") + +struct Foo + x +end + +function bar() + f = Foo(5.0) + b = Foo(f.x + 10.0) + return f +end + +v = Brutus.call(bar; dump_options = Brutus.DumpAll) +display(v) + +#println("\n---- switch ----\n") +# +#function switch(N) +# N > 10 ? 
5 : 10 +#end +# +#v = Brutus.call(switch, 15) +#display(v) + +#println("\n---- gauss ----\n") +# +#function gauss(N) +# k = 0 +# for i in 1 : N +# k += i +# end +# return k +#end +# +#mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) +#ir_code, rt = Brutus.code_ircode(mi) +#display(ir_code) +#mod = Brutus.Compiler.codegen_jlir(ir_code, rt, "gauss") +#MLIR.IR.dump(mod) + +end # module diff --git a/Brutus/src/Brutus.jl b/Brutus/src/Brutus.jl index 3469b81..449459d 100644 --- a/Brutus/src/Brutus.jl +++ b/Brutus/src/Brutus.jl @@ -11,7 +11,7 @@ import GPUCompiler: AbstractCompilerTarget, AbstractCompilerParams export emit include("init.jl") -include("codegen.jl") +include("compiler/Compiler.jl") include("reflection.jl") include("interface.jl") diff --git a/Brutus/src/codegen.jl b/Brutus/src/codegen.jl deleted file mode 100644 index 4fd1174..0000000 --- a/Brutus/src/codegen.jl +++ /dev/null @@ -1,105 +0,0 @@ -##### -##### Codegen -##### - -struct BrutusCompilerTarget <: AbstractCompilerTarget end -GPUCompiler.llvm_triple(::BrutusCompilerTarget) = Sys.MACHINE -GPUCompiler.llvm_machine(::BrutusCompilerTarget) = tm[] - -module Runtime - # the runtime library - signal_exception() = return - malloc(sz) = Base.Libc.malloc(sz) - report_oom(sz) = return - report_exception(ex) = return - report_exception_name(ex) = return - report_exception_frame(idx, func, file, line) = return -end - -@enum DumpOption::UInt8 begin - DumpIRCode = 0 - DumpTranslated = 1 - DumpCanonicalized = 2 - DumpLoweredToStd = 4 - DumpLoweredToLLVM = 8 - DumpTranslateToLLVM = 16 -end - -struct BrutusCompilerParams <: AbstractCompilerParams - emit_fptr::Bool - dump_options::Vector{DumpOption} -end - -GPUCompiler.ci_cache(job::CompilerJob{BrutusCompilerTarget}) = GLOBAL_CI_CACHE -GPUCompiler.runtime_module(job::CompilerJob{BrutusCompilerTarget}) = Runtime -GPUCompiler.isintrinsic(::CompilerJob{BrutusCompilerTarget}, fn::String) = true -GPUCompiler.can_throw(::CompilerJob{BrutusCompilerTarget}) = true 
-GPUCompiler.runtime_slug(job::CompilerJob{BrutusCompilerTarget}) = "brutus"
-
-function find_invokes(IR)
-    callees = Core.MethodInstance[]
-    for stmt in IR.stmts
-        if stmt isa Expr
-            if stmt.head == :invoke
-                mi = stmt.args[1]
-                push!(callees, mi)
-            end
-        end
-    end
-    return callees
-end
-
-# Emit MLIR IR to stdout
-function emit(job::CompilerJob)
-    ft = job.source.f
-    tt = job.source.tt
-    emit_fptr = job.params.emit_fptr
-    dump_options = job.params.dump_options
-    name = (ft <: Function) ? nameof(ft.instance) : nameof(ft)
-
-    # get first method instance matching signature
-    entry_mi = get_methodinstance(Tuple{ft, tt.parameters...})
-    IR, rt = code_ircode(entry_mi)
-
-    if DumpIRCode in dump_options
-        println("return type: ", rt)
-        println("IRCode:\n")
-        println(IR)
-    end
-
-    worklist = [IR]
-    methods = Dict{Core.MethodInstance, Tuple{Core.Compiler.IRCode, Any}}(
-        entry_mi => (IR, rt)
-    )
-
-    while !isempty(worklist)
-        code = pop!(worklist)
-        callees = find_invokes(code)
-        for callee in callees
-            if !haskey(methods, callee)
-                _code, _rt = code_ircode(callee)
-
-                methods[callee] = (_code, _rt)
-                push!(worklist, _code)
-            end
-        end
-    end
-
-    # generate LLVM bitcode and load it
-    dump_flags = reduce(|, map(UInt8, dump_options), init=0)
-    fptr = ccall((:brutus_codegen, "libbrutus"),
-                 Ptr{Nothing},
-                 (Any, Any, Cuchar, Cuchar),
-                 methods, entry_mi, emit_fptr, dump_flags)
-    return (fptr, rt)
-end
-
-function emit(@nospecialize(ft), @nospecialize(tt);
-              emit_fptr::Bool=true,
-              dump_options::Vector{DumpOption}=DumpOption[])
-    fspec = GPUCompiler.FunctionSpec(ft, Tuple{tt...}, false, nothing)
-    target = BrutusCompilerTarget()
-    params = BrutusCompilerParams(emit_fptr, dump_options)
-    job = CompilerJob(target, fspec, params)
-    return emit(job)
-end
diff --git a/Brutus/src/compiler/Compiler.jl b/Brutus/src/compiler/Compiler.jl
new file mode 100644
index 0000000..ec58fc7
--- /dev/null
+++ b/Brutus/src/compiler/Compiler.jl
@@ -0,0 +1,11 @@
+module Compiler
+
+using MLIR
+import MLIR.IR as JLIR
+import Base: push!
+
+include("opbuilder.jl")
+include("jlirgen.jl")
+include("codegen.jl")
+
+end # module
diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl
new file mode 100644
index 0000000..29b88f4
--- /dev/null
+++ b/Brutus/src/compiler/codegen.jl
@@ -0,0 +1,282 @@
+#####
+##### Codegen
+#####
+
+# This is the Julia interface between Julia's IRCode and JLIR.
+
+# Return `jlir_value` unchanged unless its Julia type is a strict subtype of
+# `expected_type`; in that case emit a `PiOp` to the expected type and return its result.
+function maybe_widen_type(b::JLIRBuilder, loc::JLIR.Location,
+                          jlir_value::JLIR.Value, expected_type::Type)
+    type = convert_jlirtype_to_type(JLIR.get_type(jlir_value))
+    if type != expected_type && type <: expected_type
+        jlir_expected_type = convert_type_to_jlirtype(b.ctx, expected_type)
+        op = create!(b, PiOp(), loc, jlir_value, jlir_expected_type)
+        return JLIR.get_result(op, 0)
+    end
+    return jlir_value
+end
+
+# Emit a `ConstantOp` carrying `value` with result type `jlir_type`; return its result.
+function emit_constant(b::JLIRBuilder, loc::JLIR.Location, value, jlir_type::JLIR.Type)
+    jlir_attr = convert_value_to_jlirattr(b.ctx, value)
+    op = create!(b, ConstantOp(), loc, jlir_attr, jlir_type)
+    return JLIR.get_result(op, 0)
+end
+
+# Fallback: an operand with no dedicated method becomes a constant typed by the runtime
+# type of the value (the inferred type argument is ignored).
+function emit_value(b::JLIRBuilder, loc::JLIR.Location,
+                    value, ::Type)
+    return emit_constant(b, loc, value, convert_type_to_jlirtype(b.ctx, typeof(value)))
+end
+
+# A QuoteNode is unwrapped and its payload is emitted as a constant.
+function emit_value(b::JLIRBuilder, loc::JLIR.Location,
+                    value::QuoteNode, ::Type)
+    inner = getfield(value, :value)
+    return emit_constant(b, loc, inner, convert_type_to_jlirtype(b.ctx, typeof(inner)))
+end
+
+# Function arguments are the arguments of the entry block (`b.blocks[1]`).
+# `Core.Argument` numbers from 1, `JLIR.get_arg` is called with `n - 1`.
+function emit_value(b::JLIRBuilder, loc::JLIR.Location,
+                    value::Core.Argument, type::Type)
+    return JLIR.get_arg(b.blocks[1], value.n - 1)
+end
+
+# SSA values are looked up in `b.values`, which `process_stmt!` fills as it goes.
+function emit_value(b::JLIRBuilder, loc::JLIR.Location,
+                    value::Core.SSAValue, type::Type)
+    @assert(value.id >= 1)
+    return getindex(b.values, value.id)
+end
+
+# A GlobalRef is resolved at codegen time and emitted as a constant of type `type`.
+function emit_value(b::JLIRBuilder, loc::JLIR.Location,
+                    value::GlobalRef, type::Type)
+    resolved = getproperty(value.mod, value.name)
+    return emit_constant(b, loc, resolved, convert_type_to_jlirtype(b.ctx, type))
+end
+
+# Build the JLIR function type from the IRCode argument types and `ret_type`.
+function emit_ftype(ctx::JLIR.Context, code::Core.Compiler.IRCode, ret_type::Type)
+    argtypes = getfield(code, :argtypes)
+    # Typed comprehension: guarantees dispatch to the Vector{JLIR.Type} method.
+    args = JLIR.Type[convert_type_to_jlirtype(ctx, a) for a in argtypes]
+    ret = convert_type_to_jlirtype(ctx, ret_type)
+    return get_functype(ctx, args, ret)
+end
+
+# Append to `v` the value that phi node `stmt` takes when control arrives from block
+# `current`. When no edge matches, or the matching slot is undefined, an `UndefOp` of
+# the phi's type is pushed instead so every phi contributes exactly one branch argument.
+function handle_node!(b::JLIRBuilder, current::Int,
+                      v::Vector{JLIR.Value}, stmt::Core.PhiNode,
+                      type::Type, loc::JLIR.Location)
+    edges = stmt.edges
+    incoming = stmt.values
+    found = false
+    for i in eachindex(edges)
+        edges[i] == current || continue
+        # Phi value slots may be left unassigned; treat those like a missing edge.
+        isassigned(incoming, i) || continue
+        val = emit_value(b, loc, incoming[i], Any)
+        push!(v, maybe_widen_type(b, loc, val, type))
+        found = true
+    end
+    if !found
+        jlir_type = convert_type_to_jlirtype(b.ctx, type)
+        op = create!(b, UndefOp(), loc, jlir_type)
+        push!(v, JLIR.get_result(op, 0))
+    end
+    return v
+end
+
+# Collect the branch arguments for a jump from `current` to `target`: one value per
+# leading phi node of the target block. `target` is an index into `b.blocks`, hence the
+# `- 1` when indexing the IRCode CFG (the entry block has no IRCode counterpart).
+# NOTE(review): callers pass `b.insertion[]` (also a `b.blocks` index) as `current`,
+# while it is compared against PhiNode edges -- confirm both use the same numbering.
+function walk_cfg_emit_branchargs(b::JLIRBuilder, current::Int,
+                                  target::Int, loc::JLIR.Location)
+    v = JLIR.Value[]
+    cfg = get_cfg(b)
+    for ind in cfg.blocks[target - 1].stmts
+        node = get_stmt(b, ind)
+        node isa Core.PhiNode || break
+        handle_node!(b, current, v, node, get_type(b, ind), loc)
+    end
+    return v
+end
+
+# `process_stmt!` lowers one IRCode statement into the current insertion block and
+# returns `true` when the statement terminated the block.
+
+function process_stmt!(b::JLIRBuilder, ind::Int,
+                       stmt::Nothing, loc::JLIR.Location, type::Type)
+    return false
+end
+
+# Any statement that is just a value (argument, SSA value, global, literal).
+function process_stmt!(b::JLIRBuilder, ind::Int,
+                       stmt, loc::JLIR.Location, type::Type)
+    setindex!(b.values, emit_value(b, loc, stmt, type), ind)
+    return false
+end
+
+function process_stmt!(b::JLIRBuilder, ind::Int,
+                       stmt::Core.GotoNode, loc::JLIR.Location, type::Type)
+    target = stmt.label + 1 # Accounts for entry block.
+    v = walk_cfg_emit_branchargs(b, b.insertion[], target, loc)
+    create!(b, GotoOp(), loc, b.blocks[target], v)
+    return true
+end
+
+function process_stmt!(b::JLIRBuilder, ind::Int,
+                       stmt::Core.GotoIfNot, loc::JLIR.Location, type::Type)
+    cond = emit_value(b, loc, stmt.cond, Any)
+    dest = stmt.dest + 1 # Accounts for entry block.
+    fallthrough = b.insertion[] + 1
+    dest_args = walk_cfg_emit_branchargs(b, b.insertion[], dest, loc)
+    fall_args = walk_cfg_emit_branchargs(b, b.insertion[], fallthrough, loc)
+    create!(b, GotoIfNotOp(), loc,
+            cond, b.blocks[dest], dest_args,
+            b.blocks[fallthrough], fall_args)
+    return true
+end
+
+# A phi node becomes a new argument of the current block.
+function process_stmt!(b::JLIRBuilder, ind::Int,
+                       stmt::Core.PhiNode, loc::JLIR.Location, type::Type)
+    t = convert_type_to_jlirtype(b.ctx, type)
+    blk = get_insertion_block(b)
+    arg = ccall((:brutusBlockAddArgument, "libbrutus"),
+                JLIR.Value,
+                (JLIR.Block, JLIR.Type),
+                blk, t)
+    setindex!(b.values, arg, ind)
+    return false
+end
+
+function process_stmt!(b::JLIRBuilder, ind::Int,
+                       stmt::Core.PiNode, loc::JLIR.Location, type::Type)
+    @assert(type == stmt.type)
+    jlir_type = convert_type_to_jlirtype(b.ctx, type)
+    op = create!(b, PiOp(), loc,
+                 emit_value(b, loc, stmt.val, Any), jlir_type)
+    setindex!(b.values, JLIR.get_result(op, 0), ind)
+    return false
+end
+
+function process_stmt!(b::JLIRBuilder, ind::Int,
+                       stmt::Core.ReturnNode, loc::JLIR.Location, type::Type)
+    if isdefined(stmt, :val)
+        jlir_v = emit_value(b, loc, stmt.val, Any)
+        value = maybe_widen_type(b, loc, jlir_v, b.rt)
+    else
+        # No return value: return the *result* of an undef op (ReturnOp takes a
+        # JLIR.Value, not the operation itself).
+        jlir_type = convert_type_to_jlirtype(b.ctx, type)
+        undef_op = create!(b, UndefOp(), loc, jlir_type)
+        value = JLIR.get_result(undef_op, 0)
+    end
+    create!(b, ReturnOp(), loc, value)
+    return true
+end
+
+# :invoke is (MethodInstance, callee, args...), :call is (callee, args...); any other
+# head is lowered to an `UnimplementedOp` of the statement's type.
+function process_stmt!(b::JLIRBuilder, ind::Int,
+                       expr::Expr, loc::JLIR.Location, type::Type)
+    head = expr.head
+    args = expr.args
+    jlir_type = convert_type_to_jlirtype(b.ctx, type)
+    if head == :invoke
+        @assert(args[1] isa Core.MethodInstance)
+        mi = args[1]
+        callee = emit_value(b, loc, args[2], Any)
+        # The callee sits at position 2, so the call arguments start at 3.
+        operands = JLIR.Value[emit_value(b, loc, a, Any)
+                              for a in args[3 : end]]
+        op = create!(b, InvokeOp(), loc, mi, callee, operands, jlir_type)
+    elseif head == :call
+        callee = emit_value(b, loc, args[1], Any)
+        operands = JLIR.Value[emit_value(b, loc, a, Any)
+                              for a in args[2 : end]]
+        op = create!(b, CallOp(), loc, callee, operands, jlir_type)
+    else
+        op = create!(b, UnimplementedOp(), loc, jlir_type)
+    end
+    setindex!(b.values, JLIR.get_result(op, 0), ind)
+    return false
+end
+
+#####
+##### JLIR generation
+#####
+
+# A module together with the context it was built in and the name of its function.
+mutable struct CompiledJLIRModule
+    ctx::JLIR.Context
+    mod::JLIR.Module
+    name::String
+end
+
+# Release the module and its context. The module is destroyed first: it was created
+# from a location in `ctx`, so the context has to outlive it.
+function cleanup!(jlir::CompiledJLIRModule)
+    JLIR.destroy!(jlir.mod)
+    JLIR.destroy!(jlir.ctx)
+end
+
+Base.display(jlir::CompiledJLIRModule) = JLIR.dump(JLIR.get_operation(jlir.mod))
+
+# Translate `ir_code` (with return type `rt`) into a verified JLIR module containing a
+# single function called `name`.
+function codegen_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String)
+    # Create builder.
+    b = JLIRBuilder(ir_code, rt, name)
+    m = JLIR.Module(JLIR.Location(b.ctx))
+
+    # Create branch from entry block.
+    v = walk_cfg_emit_branchargs(b, 1, 2, b.locations[1])
+    goto = create_goto_op(JLIR.Location(b.ctx), b.blocks[2], v)
+    push!(b.blocks[1], goto)
+
+    # Process.
+    location_indices = get_locindices(b)
+    stmts = get_stmts(b)
+    types = get_types(b)
+    for (ind, (stmt, type)) in enumerate(zip(stmts, types))
+        lt_ind = location_indices[ind]
+        # Index 0 means "no line info": use a location created in this context, as is
+        # done for the module and the entry branch above.
+        loc = lt_ind == 0 ? JLIR.Location(b.ctx) : b.locations[lt_ind]
+        if process_stmt!(b, ind, stmt, loc, type)
+            # A terminator closes the current block; continue in the next one.
+            b.insertion[] += 1
+        end
+    end
+
+    # Create op from module and verify.
+    JLIR.push_operation!(m, finish(b))
+    @assert(JLIR.verify(JLIR.get_operation(m)))
+    return CompiledJLIRModule(b.ctx, m, name)
+end
+
+# The three passes below run in libbrutus on the module in place and re-verify it.
+
+function canonicalize!(jlir::CompiledJLIRModule)
+    ccall((:brutus_canonicalize, "libbrutus"),
+          Cvoid,
+          (JLIR.Context, JLIR.Module),
+          jlir.ctx, jlir.mod)
+    @assert(JLIR.verify(JLIR.get_operation(jlir.mod)))
+    return
+end
+
+function dialect_lower_to_std!(jlir::CompiledJLIRModule)
+    ccall((:brutus_lower_to_standard, "libbrutus"),
+          Cvoid,
+          (JLIR.Context, JLIR.Module),
+          jlir.ctx, jlir.mod)
+    @assert(JLIR.verify(JLIR.get_operation(jlir.mod)))
+    return
+end
+
+function dialect_lower_to_llvm!(jlir::CompiledJLIRModule)
+    ccall((:brutus_lower_to_llvm, "libbrutus"),
+          Cvoid,
+          (JLIR.Context, JLIR.Module),
+          jlir.ctx, jlir.mod)
+    @assert(JLIR.verify(JLIR.get_operation(jlir.mod)))
+    return
+end
+
+# Ask libbrutus for an execution engine over the module and return the function
+# pointer registered under `jlir.name`.
+function thunk(jlir::CompiledJLIRModule)
+    fptr = ccall((:c_brutus_create_execution_engine, "libbrutus"),
+                 Ptr{Nothing},
+                 (JLIR.Context, JLIR.Module, Cstring),
+                 jlir.ctx, jlir.mod, jlir.name)
+    return fptr
+end
diff --git a/Brutus/src/compiler/jlirgen.jl b/Brutus/src/compiler/jlirgen.jl
new file mode 100644
index 0000000..d9350f6
--- /dev/null
+++ b/Brutus/src/compiler/jlirgen.jl
@@ -0,0 +1,178 @@
+# TODO: In future, should be autogenerated from tablegen.
+
+# The `create_*_op` functions build a detached operation from an operation state.
+# Nothing is inserted into a block here; see `create!` below for that.
+
+function create_unimplemented_op(loc::JLIR.Location, type::JLIR.Type)
+    state = JLIR.create_operation_state("jlir.unimplemented", loc)
+    JLIR.push_results!(state, type)
+    return JLIR.Operation(state)
+end
+
+function create_undef_op(loc::JLIR.Location, type::JLIR.Type)
+    state = JLIR.create_operation_state("jlir.undef", loc)
+    JLIR.push_results!(state, type)
+    return JLIR.Operation(state)
+end
+
+function create_constant_op(loc::JLIR.Location, named_attr::JLIR.NamedAttribute,
+                            type::JLIR.Type)
+    state = JLIR.create_operation_state("jlir.constant", loc)
+    JLIR.push_attributes!(state, named_attr)
+    JLIR.push_results!(state, type)
+    return JLIR.Operation(state)
+end
+
+function create_goto_op(loc::JLIR.Location, to::JLIR.Block,
+                        v::Vector{JLIR.Value})
+    state = JLIR.create_operation_state("jlir.goto", loc)
+    JLIR.push_operands!(state, v)
+    JLIR.push_successors!(state, to)
+    return JLIR.Operation(state)
+end
+
+# Operands are laid out as: condition, branch arguments for `dest`, branch arguments
+# for `fall`. Successors are `dest` then `fall`.
+function create_gotoifnot_op(loc::JLIR.Location, cond::JLIR.Value,
+                             dest::JLIR.Block, v::Vector{JLIR.Value},
+                             fall::JLIR.Block, fallv::Vector{JLIR.Value})
+    state = JLIR.create_operation_state("jlir.gotoifnot", loc)
+    JLIR.push_operands!(state, cond)
+    JLIR.push_operands!(state, v)
+    JLIR.push_operands!(state, fallv)
+    JLIR.push_successors!(state, dest)
+    JLIR.push_successors!(state, fall)
+    return JLIR.Operation(state)
+end
+
+function create_pi_op(loc::JLIR.Location, value::JLIR.Value,
+                      type::JLIR.Type)
+    state = JLIR.create_operation_state("jlir.pi", loc)
+    JLIR.push_results!(state, type)
+    JLIR.push_operands!(state, value)
+    return JLIR.Operation(state)
+end
+
+function create_return_op(loc::JLIR.Location, input::JLIR.Value)
+    state = JLIR.create_operation_state("jlir.return", loc)
+    JLIR.push_operands!(state, input)
+    return JLIR.Operation(state)
+end
+
+# Operands are the callee followed by the call arguments.
+function create_call_op(loc::JLIR.Location, callee::JLIR.Value,
+                        arguments::Vector{JLIR.Value}, type::JLIR.Type)
+    state = JLIR.create_operation_state("jlir.call", loc)
+    operands = JLIR.Value[callee, arguments...]
+    JLIR.push_operands!(state, operands)
+    JLIR.push_results!(state, type)
+    return JLIR.Operation(state)
+end
+
+# Operands are the method instance, the callee, then the call arguments.
+function create_invoke_op(loc::JLIR.Location, mi::JLIR.Value,
+                          callee::JLIR.Value, arguments::Vector{JLIR.Value},
+                          type::JLIR.Type)
+    state = JLIR.create_operation_state("jlir.invoke", loc)
+    JLIR.push_operands!(state, JLIR.Value[mi, callee, arguments...])
+    JLIR.push_results!(state, type)
+    return JLIR.Operation(state)
+end
+
+#####
+##### High-level version of create
+#####
+
+# Tag types selecting which operation `create!` builds.
+struct UnimplementedOp end
+struct UndefOp end
+struct ConstantOp end
+struct GotoOp end
+struct GotoIfNotOp end
+struct PiOp end
+struct ReturnOp end
+struct CallOp end
+struct InvokeOp end
+
+# Append `op` to the builder's current insertion block and return it. With
+# `verify = true` the detached op is verified first; the terminators (goto, gotoifnot,
+# return) are inserted unverified, as before.
+function _insert_op!(b::JLIRBuilder, op::JLIR.Operation; verify::Bool = true)
+    if verify
+        @assert(JLIR.verify(op))
+    end
+    JLIR.push_operation!(get_insertion_block(b), op)
+    return op
+end
+
+# `create!(b, Tag(), loc, ...)` builds the operation selected by the tag, inserts it at
+# the builder's insertion point and returns the operation.
+
+function create!(b::JLIRBuilder, ::UnimplementedOp,
+                 loc::JLIR.Location, type::JLIR.Type)
+    return _insert_op!(b, create_unimplemented_op(loc, type))
+end
+
+function create!(b::JLIRBuilder, ::UndefOp,
+                 loc::JLIR.Location, type::JLIR.Type)
+    return _insert_op!(b, create_undef_op(loc, type))
+end
+
+function create!(b::JLIRBuilder, ::ConstantOp,
+                 loc::JLIR.Location, value::JLIR.Attribute, type::JLIR.Type)
+    named_attr = JLIR.NamedAttribute(b.ctx, "value", value)
+    return _insert_op!(b, create_constant_op(loc, named_attr, type))
+end
+
+function create!(b::JLIRBuilder, ::GotoOp,
+                 loc::JLIR.Location, to::JLIR.Block, v::Vector{JLIR.Value})
+    return _insert_op!(b, create_goto_op(loc, to, v); verify = false)
+end
+
+function create!(b::JLIRBuilder, ::GotoIfNotOp,
+                 loc::JLIR.Location, cond::JLIR.Value, dest::JLIR.Block,
+                 v::Vector{JLIR.Value}, fall::JLIR.Block, fallv::Vector{JLIR.Value})
+    op = create_gotoifnot_op(loc, cond, dest, v, fall, fallv)
+    return _insert_op!(b, op; verify = false)
+end
+
+function create!(b::JLIRBuilder, ::PiOp,
+                 loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type)
+    return _insert_op!(b, create_pi_op(loc, value, type))
+end
+
+function create!(b::JLIRBuilder, ::ReturnOp,
+                 loc::JLIR.Location, input::JLIR.Value)
+    return _insert_op!(b, create_return_op(loc, input); verify = false)
+end
+
+function create!(b::JLIRBuilder, ::CallOp,
+                 loc::JLIR.Location, callee::JLIR.Value, arguments::Vector{JLIR.Value},
+                 type::JLIR.Type)
+    return _insert_op!(b, create_call_op(loc, callee, arguments, type))
+end
+
+function create!(b::JLIRBuilder, ::InvokeOp,
+                 loc::JLIR.Location, mi::Core.MethodInstance, callee::JLIR.Value,
+                 arguments::Vector{JLIR.Value}, type::JLIR.Type)
+    # `create_invoke_op` takes the method instance as a JLIR.Value operand, so the
+    # MethodInstance is first materialised as a constant in the insertion block.
+    # (`convert_value_to_jlirattr` needs the context, not the builder.)
+    # NOTE(review): if the jlir.invoke op expects the method instance as an attribute
+    # rather than an operand, `create_invoke_op` itself needs changing -- confirm
+    # against the dialect definition.
+    mi_attr = convert_value_to_jlirattr(b.ctx, mi)
+    mi_type = convert_type_to_jlirtype(b.ctx, Core.MethodInstance)
+    mi_op = create!(b, ConstantOp(), loc, mi_attr, mi_type)
+    op = create_invoke_op(loc, JLIR.get_result(mi_op, 0), callee, arguments, type)
+    return _insert_op!(b, op)
+end
diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl
new file mode 100644
index 0000000..bb0af23
--- /dev/null
+++ b/Brutus/src/compiler/opbuilder.jl
@@ -0,0 +1,146 @@
+#####
+##### Builder
+#####
+
+# High-level version of MLIR's OpBuilder.
+
+# State carried while one IRCode is translated into one JLIR function.
+#   insertion  index into `blocks` of the block currently being filled
+#   values     JLIR value produced for each IRCode statement index
+#   arguments  JLIR types of the function arguments
+#   locations  one JLIR location per linetable entry
+#   blocks     entry block first, then one block per IRCode basic block
+#   state      operation state of the enclosing "func" op, consumed by `finish`
+struct JLIRBuilder
+    ctx::JLIR.Context
+    insertion::Ref{Int}
+    values::Dict{Int, JLIR.Value}
+    arguments::Vector{JLIR.Type}
+    locations::Vector{JLIR.Location}
+    blocks::Vector{JLIR.Block}
+    reg::JLIR.Region
+    code::Core.Compiler.IRCode
+    rt::Type
+    state::JLIR.OperationState
+end
+
+function JLIRBuilder(code::Core.Compiler.IRCode, rt::Type, name::String)
+
+    # Create a context and register dialects required by Brutus.
+    ctx = JLIR.create_context()
+    ccall((:brutus_register_dialects, "libbrutus"),
+          Cvoid,
+          (JLIR.Context, ),
+          ctx)
+
+    # IRCode metadata -> JLIR metadata (locations).
+    linetable = getfield(code, :linetable)
+    locations = extract_linetable_locations(ctx, linetable)
+
+    # Create toplevel FuncOp.
+    argtypes = getfield(code, :argtypes)
+    args = JLIR.Type[convert_type_to_jlirtype(ctx, a) for a in argtypes]
+    ftype = emit_ftype(ctx, code, rt)
+    state = JLIR.create_operation_state("func", locations[1])
+    type_attr = JLIR.get_type_attribute(ftype)
+    named_type_attr = JLIR.NamedAttribute(ctx, "type", type_attr)
+    string_attr = JLIR.get_string_attribute(ctx, name)
+    symbol_name_attr = JLIR.NamedAttribute(ctx, "sym_name", string_attr)
+    viz_attr = JLIR.get_string_attribute(ctx, "nested")
+    named_viz_attr = JLIR.NamedAttribute(ctx, "sym_visibility", viz_attr)
+    unit_attr = JLIR.get_unit_attribute(ctx)
+    JLIR.push_attributes!(state, named_type_attr)
+    JLIR.push_attributes!(state, symbol_name_attr)
+    JLIR.push_attributes!(state, named_viz_attr)
+    JLIR.push_attributes!(state, JLIR.NamedAttribute(ctx, "llvm.emit_c_interface", unit_attr))
+
+    # One JLIR block per IRCode basic block, placed after the entry block.
+    entry_blk, reg = JLIR.add_entry_block!(state, args)
+    blocks = JLIR.Block[entry_blk]
+    for _ in 1 : length(code.cfg.blocks)
+        blk = JLIR.Block()
+        JLIR.push!(reg, blk)
+        push!(blocks, blk)
+    end
+
+    # Pass FuncOp state in builder. Insertion starts at block 2, the first IRCode block.
+    return JLIRBuilder(ctx, Ref(2), Dict{Int, JLIR.Value}(), args, locations, blocks, reg, code, rt, state)
+end
+
+set_insertion!(b::JLIRBuilder, blk::Int) = b.insertion[] = blk
+get_insertion_block(b::JLIRBuilder) = b.blocks[b.insertion[]]
+
+get_locindices(b::JLIRBuilder) = b.code.stmts.line
+get_stmts(b::JLIRBuilder) = b.code.stmts.inst
+get_types(b::JLIRBuilder) = b.code.stmts.type
+get_stmt(b::JLIRBuilder, ind::Int) = getindex(b.code.stmts.inst, ind)
+get_type(b::JLIRBuilder, ind::Int) = getindex(b.code.stmts.type, ind)
+get_cfg(b::JLIRBuilder) = b.code.cfg
+
+# Append `op` to the current insertion block. `b.insertion` is a Ref and has to be
+# dereferenced, which `get_insertion_block` does.
+function push!(b::JLIRBuilder, op::JLIR.Operation)
+    return JLIR.push_operation!(get_insertion_block(b), op)
+end
+
+# Turn the accumulated FuncOp state into the function operation.
+finish(b::JLIRBuilder) = JLIR.Operation(b.state)
+
+#####
+##### Utilities
+#####
+
+# Explicitly exposed as part of extern C in codegen.cpp.
+
+function convert_type_to_jlirtype(ctx::JLIR.Context, a)
+    return ccall((:brutus_get_jlirtype, "libbrutus"),
+                 JLIR.Type,
+                 (JLIR.Context, Any),
+                 ctx, a)
+end
+
+function convert_value_to_jlirattr(ctx::JLIR.Context, a)
+    return ccall((:brutus_get_jlirattr, "libbrutus"),
+                 JLIR.Attribute,
+                 (JLIR.Context, Any),
+                 ctx, a)
+end
+
+function convert_jlirtype_to_type(v::JLIR.Type)
+    return ccall((:brutus_get_julia_type, "libbrutus"),
+                 Type,
+                 (JLIR.Type, ),
+                 v)
+end
+
+# Pass JLIR types through untouched, convert anything else from a Julia type.
+_as_jlirtype(::JLIR.Context, t::JLIR.Type) = t
+_as_jlirtype(ctx::JLIR.Context, t) = convert_type_to_jlirtype(ctx, t)
+
+# Function type with the given argument types and a single result type.
+function get_functype(ctx::JLIR.Context, args::Vector{JLIR.Type}, ret::JLIR.Type)
+    return MLIR.API.mlirFunctionTypeGet(ctx, length(args), args, 1, [ret])
+end
+
+# Convenience method taking Julia types (or a mix): convert, then defer to the method
+# above.
+function get_functype(ctx::JLIR.Context, args, ret)
+    jlir_args = JLIR.Type[_as_jlirtype(ctx, a) for a in args]
+    return get_functype(ctx, jlir_args, _as_jlirtype(ctx, ret))
+end
+
+# A linetable entry may name its method through a MethodInstance; look through it.
+function unwrap(mi::Core.MethodInstance)
+    return mi.def
+end
+unwrap(s) = s
+
+# Printable name of whatever a linetable entry refers to.
+_location_name(m::Method) = String(m.name)
+_location_name(s::Symbol) = String(s)
+_location_name(x) = string(x)
+
+# Build one JLIR location per linetable entry. An entry with a non-zero `inlined_at`
+# is wrapped together with the location of the entry it was inlined into; that entry
+# has a smaller index and has therefore already been pushed.
+# NOTE(review): the name derived from `n.method` is passed as the first string of
+# `JLIR.Location` and `n.file` is not used -- confirm which of the two is intended.
+function extract_linetable_locations(ctx::JLIR.Context, v::Vector{Core.LineInfoNode})
+    locations = JLIR.Location[]
+    for n in v
+        fname = _location_name(unwrap(n.method))
+        line = n.line
+        inlined_at = n.inlined_at
+        current = JLIR.Location(ctx, fname, UInt32(line), UInt32(0)) # TODO: col.
+        # `inlined_at` is a 1-based linetable index, 0 meaning "not inlined".
+        if inlined_at > 0
+            current = JLIR.Location(current, locations[inlined_at])
+        end
+        push!(locations, current)
+    end
+    return locations
+end
diff --git a/Brutus/src/interface.jl b/Brutus/src/interface.jl
index 4fe859b..4ba9194 100644
--- a/Brutus/src/interface.jl
+++ b/Brutus/src/interface.jl
@@ -1,3 +1,141 @@
+#####
+##### GPUCompiler codegen
+#####
+
+struct BrutusCompilerTarget <: AbstractCompilerTarget end
+GPUCompiler.llvm_triple(::BrutusCompilerTarget) = Sys.MACHINE
+GPUCompiler.llvm_machine(::BrutusCompilerTarget) = tm[]
+
+module Runtime
+    # the runtime library
+    signal_exception() = return
+    malloc(sz) = Base.Libc.malloc(sz)
+    report_oom(sz) = return
+    report_exception(ex) = return
+    report_exception_name(ex) = return
+    report_exception_frame(idx, func, file, line) = return
+end
+
+@enum DumpOption::UInt8 begin
+    DumpIRCode = 0
+    DumpTranslated = 1
+    DumpCanonicalized = 2
+    DumpLoweredToStd = 4
+    DumpLoweredToLLVM = 8
+    DumpTranslateToLLVM = 16
+end
+
+const DumpAll = DumpOption[DumpIRCode,
+                           DumpTranslated,
+                           DumpCanonicalized,
+                           DumpLoweredToStd,
+                           DumpLoweredToLLVM]
+
+struct BrutusCompilerParams <: AbstractCompilerParams
+    emit_fptr::Bool
+    dump_options::Vector{DumpOption}
+end
+
+GPUCompiler.ci_cache(job::CompilerJob{BrutusCompilerTarget}) = GLOBAL_CI_CACHE
+GPUCompiler.runtime_module(job::CompilerJob{BrutusCompilerTarget}) = Runtime
+GPUCompiler.isintrinsic(::CompilerJob{BrutusCompilerTarget}, fn::String) = true
+GPUCompiler.can_throw(::CompilerJob{BrutusCompilerTarget}) = true
+GPUCompiler.runtime_slug(job::CompilerJob{BrutusCompilerTarget}) = "brutus"
+
+function find_invokes(IR)
+    callees = Core.MethodInstance[]
+    for stmt in IR.stmts
+        if stmt isa Expr
+            if stmt.head == :invoke
+                mi = stmt.args[1]
+                push!(callees, mi)
+            end
+        end
+    end
+    return callees
+end
+
+# Emit MLIR IR to
stdout
+
+# Print `title` and the current module when `option` is among `dump_options`.
+function _dump_stage(jlir, title::String, option::DumpOption,
+                     dump_options::Vector{DumpOption})
+    option in dump_options || return nothing
+    println(title)
+    display(jlir)
+    println()
+    return nothing
+end
+
+# Compile the job's entry method: IRCode -> JLIR -> canonicalize -> standard dialect
+# -> LLVM dialect -> execution engine. Returns `(fptr, rt)` where `rt` is the inferred
+# return type. Intermediate stages are printed according to the job's dump options.
+function emit(job::CompilerJob)
+    ft = typeof(job.source.f)
+    tt = job.source.tt
+    dump_options = job.params.dump_options
+    name = (ft <: Function) ? nameof(ft.instance) : nameof(ft)
+
+    # get first method instance matching signature
+    entry_mi = get_methodinstance(Tuple{ft, tt.parameters...})
+    IR, rt = code_ircode(entry_mi)
+
+    if DumpIRCode in dump_options
+        println("return type: ", rt)
+        println("IRCode:\n")
+        println(IR)
+        println()
+    end
+
+    # TODO: only the entry method is translated; methods reached through :invoke
+    # (see find_invokes) are not collected and compiled yet.
+
+    jlir = Brutus.Compiler.codegen_jlir(IR, rt, String(name))
+    try
+        _dump_stage(jlir, "JLIR:", DumpTranslated, dump_options)
+
+        Brutus.Compiler.canonicalize!(jlir)
+        _dump_stage(jlir, "After canonicalization:", DumpCanonicalized, dump_options)
+
+        Brutus.Compiler.dialect_lower_to_std!(jlir)
+        _dump_stage(jlir, "Standard:", DumpLoweredToStd, dump_options)
+
+        Brutus.Compiler.dialect_lower_to_llvm!(jlir)
+        _dump_stage(jlir, "LLVM dialect:", DumpLoweredToLLVM, dump_options)
+
+        fptr = Brutus.Compiler.thunk(jlir)
+        return (fptr, rt)
+    finally
+        # Release the module and context on every path, including a failing stage.
+        Brutus.Compiler.cleanup!(jlir)
+    end
+end
+
+# Convenience entry point: build the compiler job for `ft` applied to argument types
+# `tt` and run `emit` on it.
+function emit(@nospecialize(ft), @nospecialize(tt);
+              emit_fptr::Bool=true,
+              dump_options::Vector{DumpOption}=DumpOption[])
+    fspec = GPUCompiler.FunctionSpec(ft, Tuple{tt...}, false, nothing)
+    target = BrutusCompilerTarget()
+    params = BrutusCompilerParams(emit_fptr, dump_options)
+    job = CompilerJob(target, fspec, params)
+    return emit(job)
+end
+
 #####
 ##### Call Interface
##### @@ -29,17 +167,18 @@ struct Thunk{F, RT, TT} ptr::Ptr{Cvoid} end -const brutus_cache = Dict{UInt,Any}() - function link(job::CompilerJob, (fptr, rt)) @assert fptr != C_NULL - fptr, rt = result f = job.source.f tt = job.source.tt return Thunk{typeof(f), rt, tt}(f, fptr) end -function thunk(f::F, tt::TT=Tuple{}; emit_fptr::Bool = true, dump_options::Vector{DumpOption} = DumpOption[]) where {F<:Base.Callable, TT<:Type} +const brutus_cache = Dict{UInt,Any}() + +function thunk(f::F, tt::TT=Tuple{}; + emit_fptr::Bool = true, + dump_options::Vector{DumpOption} = DumpOption[]) where {F <: Base.Callable, TT <: Type} fspec = GPUCompiler.FunctionSpec(f, tt, false, nothing) target = BrutusCompilerTarget() params = BrutusCompilerParams(emit_fptr, dump_options) @@ -47,7 +186,7 @@ function thunk(f::F, tt::TT=Tuple{}; emit_fptr::Bool = true, dump_options::Vecto return GPUCompiler.cached_compilation(brutus_cache, job, emit, link) end -# Need to pass struct as pointer, to match cifacme ABI +# Need to pass struct as pointer, to match ciface ABI abi(::Type{<:Array{T, N}}) where {T, N} = Ref{MemrefDescriptor{T, N}} function abi(T::DataType) if isprimitivetype(T) @@ -72,7 +211,8 @@ end return expr end -function call(f::F, args...) where F +function call(f::F, args...; + dump_options::Vector{DumpOption} = DumpOption[]) where F TT = Tuple{map(Core.Typeof, args)...} - return thunk(f, TT)(args...) + return thunk(f, TT; dump_options = dump_options)(args...) 
end diff --git a/Brutus/src/reflection.jl b/Brutus/src/reflection.jl index 17870df..b311993 100644 --- a/Brutus/src/reflection.jl +++ b/Brutus/src/reflection.jl @@ -4,7 +4,6 @@ function get_methodinstance(@nospecialize(sig); ms = Base._methods_by_ftype(sig, 1, Base.get_world_counter()) @assert length(ms) == 1 m = ms[1] - display(m) mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any), m[3], m[1], m[2]) diff --git a/Brutus/test/runtests.jl b/Brutus/test/runtests.jl index e89665e..b571c37 100644 --- a/Brutus/test/runtests.jl +++ b/Brutus/test/runtests.jl @@ -68,4 +68,5 @@ for array in [rand(Int64, 2, 3), rand(Int64, 2, 3)] @test Brutus.call(customsum, array) == customsum(array) end end + # TODO: arrays with floating point elements diff --git a/include/brutus/Dialect/Julia/JuliaOps.h b/include/brutus/Dialect/Julia/JuliaOps.h index fa52ea1..878abf3 100644 --- a/include/brutus/Dialect/Julia/JuliaOps.h +++ b/include/brutus/Dialect/Julia/JuliaOps.h @@ -1,6 +1,8 @@ #ifndef JL_DIALECT_JLIR_H #define JL_DIALECT_JLIR_H +#include + #include "mlir/IR/Dialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" diff --git a/include/brutus/Dialect/Julia/JuliaOps.td b/include/brutus/Dialect/Julia/JuliaOps.td index 1690d99..e034351 100644 --- a/include/brutus/Dialect/Julia/JuliaOps.td +++ b/include/brutus/Dialect/Julia/JuliaOps.td @@ -380,4 +380,4 @@ def JLIR_Builtin_ifelse : JLIR_IntrinsicBuiltinOp<"ifelse">; def JLIR_Builtin__typevar : JLIR_IntrinsicBuiltinOp<"_typevar">; // invoke_kwsorter? 
-#endif // JULIA_MLIR_JLIR_TD \ No newline at end of file +#endif // JULIA_MLIR_JLIR_TD diff --git a/include/brutus/brutus.h b/include/brutus/brutus.h index d280adf..7c9b1f4 100644 --- a/include/brutus/brutus.h +++ b/include/brutus/brutus.h @@ -24,15 +24,25 @@ extern "C" { #endif - + void brutus_register_dialects(MlirContext context); + MlirType brutus_get_jlirtype(MlirContext context, jl_datatype_t *datatype); + jl_datatype_t *brutus_get_julia_type(MlirType v); + MlirAttribute brutus_get_jlirattr(MlirContext context, jl_value_t *value); + + // TODO: deprecate -- should be available in MLIR C API. + void brutus_register_extern_dialect(MlirContext context, MlirDialect dialect); + MlirValue brutusBlockAddArgument(MlirBlock block, MlirType type); + + // Export C API for pipeline. typedef void (*ExecutionEngineFPtrResult)(void **); void brutus_init(jl_module_t *brutus); - void brutus_codegen_jlir(MlirContext context, MlirModule module, jl_value_t *methods, jl_method_instance_t *entry_mi, char dump_flags); - void brutus_canonicalize(MlirContext context, MlirModule module, char dump_flags); - void brutus_lower_to_standard(MlirContext context, MlirModule module, char dump_flags); - void brutus_lower_to_llvm(MlirContext context, MlirModule module, char dump_flags); + void brutus_codegen_jlir(MlirContext context, MlirModule module, jl_value_t *methods, jl_method_instance_t *entry_mi); + void brutus_canonicalize(MlirContext context, MlirModule module); + void brutus_lower_to_standard(MlirContext context, MlirModule module); + void brutus_lower_to_llvm(MlirContext context, MlirModule module); ExecutionEngineFPtrResult brutus_create_execution_engine(MlirContext context, MlirModule module, std::string name); + ExecutionEngineFPtrResult c_brutus_create_execution_engine(MlirContext context, MlirModule module, const char *name); ExecutionEngineFPtrResult brutus_codegen(jl_value_t *methods, jl_method_instance_t *entry_mi, char emit_fptr, char dump_flags); #ifdef __cplusplus diff 
--git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index fd15665..a8f073c 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -1,4 +1,3 @@ - #include "brutus/brutus.h" #include "brutus/brutus_internal.h" #include "brutus/Dialect/Julia/JuliaOps.h" @@ -17,6 +16,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" #include "mlir/Target/LLVMIR.h" +#include "llvm-c/Core.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Wrap.h" #include "mlir/CAPI/IR.h" @@ -469,32 +469,58 @@ mlir::FuncOp emit_function(jl_mlirctx_t &ctx, extern "C" { - enum DumpOption + // TODO: deprecate -- available in MLIR C API. + void brutus_register_extern_dialect(MlirContext Context, MlirDialect Dialect) { - // DUMP_IRCODE = 0, - DUMP_TRANSLATED = 1, - DUMP_CANONICALIZED = 2, - DUMP_LOWERED_TO_STD = 4, - DUMP_LOWERED_TO_LLVM = 8, - DUMP_TRANSLATE_TO_LLVM = 16, + return; + } + + void brutus_register_dialects(MlirContext Context) + { + mlir::MLIRContext *ctx = unwrap(Context); + ctx->getOrLoadDialect(); + ctx->getOrLoadDialect(); + ctx->getOrLoadDialect(); + }; + + MlirType brutus_get_jlirtype(MlirContext Context, + jl_datatype_t *datatype) + { + mlir::MLIRContext *ctx = unwrap(Context); + mlir::Type type = JuliaType::get(ctx, datatype); + return wrap(type); + }; + + jl_datatype_t *brutus_get_julia_type(MlirType v) { + mlir::Type type = unwrap(v); + return (jl_datatype_t *)type.cast().getDatatype(); + } + + MlirAttribute brutus_get_jlirattr(MlirContext Context, + jl_value_t *value) + { + mlir::MLIRContext *ctx = unwrap(Context); + mlir::Attribute val = JuliaValueAttr::get(ctx, value); + return wrap(val); }; - // TODO: enum with ERROR codes for failures. + // TODO: deprecate -- available in MLIR C API. 
+ MlirValue brutusBlockAddArgument(MlirBlock block, MlirType type) + { + return wrap(unwrap(block)->addArgument(unwrap(type))); + } + void brutus_codegen_jlir(MlirContext Context, MlirModule Module, jl_value_t *methods, - jl_method_instance_t *entry_mi, - char dump_flags) + jl_method_instance_t *entry_mi) { - mlir::MLIRContext *context = unwrap(Context); mlir::ModuleOp module = unwrap(Module); + brutus_register_dialects(Context); + mlir::MLIRContext *context = unwrap(Context); jl_mlirctx_t ctx(context); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - jl_value_t *entry = jl_call2(getindex_func, methods, (jl_value_t *)entry_mi); jl_value_t *ir_code = jl_fieldref(entry, 0); jl_value_t *ret_type = jl_fieldref(entry, 1); @@ -517,8 +543,7 @@ extern "C" // canonicalize void brutus_canonicalize(MlirContext Context, - MlirModule Module, - char dump_flags) + MlirModule Module) { mlir::MLIRContext *context = unwrap(Context); mlir::ModuleOp module = unwrap(Module); @@ -542,8 +567,7 @@ extern "C" // lower to Standard dialect void brutus_lower_to_standard(MlirContext Context, - MlirModule Module, - char dump_flags) + MlirModule Module) { mlir::MLIRContext *context = unwrap(Context); mlir::ModuleOp module = unwrap(Module); @@ -563,8 +587,7 @@ extern "C" // lower to LLVM dialect void brutus_lower_to_llvm(MlirContext Context, - MlirModule Module, - char dump_flags) + MlirModule Module) { mlir::MLIRContext *context = unwrap(Context); mlir::ModuleOp module = unwrap(Module); @@ -629,12 +652,22 @@ extern "C" return expectedFPtr.get(); } + enum DumpOption + { + // DUMP_IRCODE = 0, + DUMP_TRANSLATED = 1, + DUMP_CANONICALIZED = 2, + DUMP_LOWERED_TO_STD = 4, + DUMP_LOWERED_TO_LLVM = 8, + DUMP_TRANSLATE_TO_LLVM = 16, + }; + ExecutionEngineFPtrResult brutus_codegen(jl_value_t *methods, jl_method_instance_t *entry_mi, char emit_fptr, char dump_flags) { MlirContext Context = mlirContextCreate(); MlirModule Module = 
mlirModuleCreateEmpty(mlirLocationUnknownGet(Context)); - brutus_codegen_jlir(Context, Module, methods, entry_mi, dump_flags); + brutus_codegen_jlir(Context, Module, methods, entry_mi); if (dump_flags && DUMP_TRANSLATED) { mlir::ModuleOp module = unwrap(Module); @@ -643,7 +676,7 @@ extern "C" llvm::dbgs() << "\n\n"; } - brutus_canonicalize(Context, Module, dump_flags); + brutus_canonicalize(Context, Module); if (dump_flags & DUMP_CANONICALIZED) { mlir::ModuleOp module = unwrap(Module); @@ -652,7 +685,7 @@ extern "C" llvm::dbgs() << "\n\n"; } - brutus_lower_to_standard(Context, Module, dump_flags); + brutus_lower_to_standard(Context, Module); if (dump_flags & DUMP_LOWERED_TO_STD) { mlir::ModuleOp module = unwrap(Module); @@ -661,7 +694,7 @@ extern "C" llvm::dbgs() << "\n\n"; } - brutus_lower_to_llvm(Context, Module, dump_flags); + brutus_lower_to_llvm(Context, Module); if (dump_flags & DUMP_LOWERED_TO_LLVM) { mlir::ModuleOp module = unwrap(Module); @@ -693,4 +726,12 @@ extern "C" return engine_ptr; } + + ExecutionEngineFPtrResult c_brutus_create_execution_engine(MlirContext Context, + MlirModule Module, + const char *name) + { + std::string str(name); + return brutus_create_execution_engine(Context, Module, str); + } }