Skip to content
Open
10 changes: 5 additions & 5 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,8 +1106,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
wrapftype = MLIR.IR.Type(
MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false)
)
wrapfunc = MLIR.IR.with_block(MLIR.IR.body(mod)) do
return MLIR.Dialects.llvm.func(;
wrapfunc = MLIR.IR.@scope MLIR.IR.body(mod) begin
MLIR.Dialects.llvm.func(;
sym_name,
sym_visibility=MLIR.IR.Attribute("private"),
function_type=wrapftype,
Expand Down Expand Up @@ -1153,7 +1153,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
end

# TODO(#2240): check for only integer and explicitly non cutraced types
MLIR.IR.with_block(wrapbody) do
MLIR.IR.@scope wrapbody begin
argty = MLIR.IR.Type(
MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1)
)
Expand Down Expand Up @@ -1224,7 +1224,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
julia_arg = allargs[p[2]]

offset = get_field_offset(typeof(julia_arg), p[3:end])
MLIR.IR.with_block(wrapbody) do
MLIR.IR.@scope wrapbody begin
ptr = MLIR.IR.result(
MLIR.Dialects.llvm.getelementptr(
alloc,
Expand All @@ -1241,7 +1241,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
argidx += 1
end

MLIR.IR.with_block(wrapbody) do
MLIR.IR.@scope wrapbody begin
for arg in allocs
if arg === nothing
continue
Expand Down
19 changes: 9 additions & 10 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1542,11 +1542,13 @@ end

# helper for debug purposes: String -> Text
function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true)
return MLIR.IR.with_context() do _
mod = parse(MLIR.IR.Module, source)
run_pass_pipeline!(mod, pass_pipeline; enable_verifier)
MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true)
Text(repr(mod))
MLIR.IR.@dispose ctx = Reactant.ReactantContext() begin
MLIR.IR.@scope ctx begin
mod = parse(MLIR.IR.Module, source)
run_pass_pipeline!(mod, pass_pipeline; enable_verifier)
MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true)
Text(repr(mod))
end
end
end

Expand Down Expand Up @@ -1833,9 +1835,6 @@ function compile_mlir!(
@assert MLIR.IR.current_context() == MLIR.IR.context(mod)
client = client !== nothing ? client : XLA.default_backend()

# Explicitly don't use with_block to avoid creating a closure, which creates
# both compile-time and relocatability issues

MLIR.IR.activate(mod)
MLIR.IR.activate(MLIR.IR.body(mod))
activate_callcache!(callcache)
Expand Down Expand Up @@ -2641,8 +2640,8 @@ function compile_mlir!(

MLIR.IR.dispose(ret)

MLIR.IR.with_block(fnbody) do
return MLIR.Dialects.func.return_(nresults)
MLIR.IR.@scope fnbody begin
MLIR.Dialects.func.return_(nresults)
end

out_tys2 = [MLIR.IR.type(a) for a in nresults]
Expand Down
32 changes: 14 additions & 18 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2381,8 +2381,8 @@ end

# compile the true branch without any returns first
true_fn_mod = MLIR.IR.current_module()
true_func_tmp = MLIR.IR.with_block(MLIR.IR.body(true_fn_mod)) do
return MLIR.Dialects.func.func_(;
true_func_tmp = MLIR.IR.@scope MLIR.IR.body(true_fn_mod) begin
MLIR.Dialects.func.func_(;
sym_name=string(true_fn) * "_tb_tmp",
function_type=MLIR.IR.FunctionType(input_types, []),
body=MLIR.IR.Region(),
Expand Down Expand Up @@ -2447,8 +2447,8 @@ end

# compile the false branch without any returns similar to the true branch
false_fn_mod = MLIR.IR.current_module()
false_func_tmp = MLIR.IR.with_block(MLIR.IR.body(false_fn_mod)) do
return MLIR.Dialects.func.func_(;
false_func_tmp = MLIR.IR.@scope MLIR.IR.body(false_fn_mod) begin
MLIR.Dialects.func.func_(;
sym_name=string(false_fn) * "_fb_tmp",
function_type=MLIR.IR.FunctionType(input_types, []),
body=MLIR.IR.Region(),
Expand Down Expand Up @@ -2675,8 +2675,8 @@ end
# With the corrected results, we can compile the true and false branches
tb_out_types = [mlir_type(tr) for tr in tb_corrected_linear_results]

true_fn_compiled = MLIR.IR.with_block(MLIR.IR.body(true_fn_mod)) do
return MLIR.Dialects.func.func_(;
true_fn_compiled = MLIR.IR.@scope MLIR.IR.body(true_fn_mod) begin
MLIR.Dialects.func.func_(;
sym_name=Reactant.TracedUtils.__lookup_unique_name_in_module(
true_fn_mod, string(true_fn) * "_tb"
),
Expand All @@ -2692,8 +2692,8 @@ end

fb_out_types = [mlir_type(fr) for fr in fb_corrected_linear_results]

false_fn_compiled = MLIR.IR.with_block(MLIR.IR.body(false_fn_mod)) do
return MLIR.Dialects.func.func_(;
false_fn_compiled = MLIR.IR.@scope MLIR.IR.body(false_fn_mod) begin
MLIR.Dialects.func.func_(;
sym_name=Reactant.TracedUtils.__lookup_unique_name_in_module(
false_fn_mod, string(false_fn) * "_fb"
),
Expand Down Expand Up @@ -2846,8 +2846,8 @@ result = Ops.case(
branch_results = Vector{Any}(undef, n_branches)

for b in 1:n_branches
branch_func_tmps[b] = MLIR.IR.with_block(MLIR.IR.body(branch_mods[b])) do
return MLIR.Dialects.func.func_(;
branch_func_tmps[b] = MLIR.IR.@scope MLIR.IR.body(branch_mods[b]) begin
MLIR.Dialects.func.func_(;
sym_name=string(branch_fns[b]) * "_branch$(b)_tmp",
function_type=MLIR.IR.FunctionType(input_types, []),
body=MLIR.IR.Region(),
Expand Down Expand Up @@ -3047,8 +3047,8 @@ result = Ops.case(
for b in 1:n_branches
branch_out_types = [mlir_type(tr) for tr in branch_corrected_linear_results[b]]

branch_fn_compiled = MLIR.IR.with_block(MLIR.IR.body(branch_mods[b])) do
return MLIR.Dialects.func.func_(;
branch_fn_compiled = MLIR.IR.@scope MLIR.IR.body(branch_mods[b]) begin
MLIR.Dialects.func.func_(;
sym_name=Reactant.TracedUtils.__lookup_unique_name_in_module(
branch_mods[b], string(branch_fns[b]) * "_branch$(b)"
),
Expand Down Expand Up @@ -3296,15 +3296,11 @@ end
)

sym_name = Reactant.TracedUtils.__lookup_unique_name_in_module(mod, sym_name)

mesh_op = MLIR.IR.with_module(mod) do
return MLIR.Dialects.sdy.mesh(; sym_name, mesh=mesh_attr, location)
end
mesh_op = MLIR.Dialects.sdy.mesh(; sym_name, mesh=mesh_attr, location)

# mesh_op needs to be moved to the beginning of the module
mesh_op = MLIR.IR.rmfromparent!(mesh_op)
mod_body = MLIR.IR.body(mod)
pushfirst!(mod_body, mesh_op)
pushfirst!(MLIR.IR.body(mod), mesh_op)

# We return the name of the mesh, since the operation is a Symbol op
return (;
Expand Down
33 changes: 17 additions & 16 deletions src/Sharding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -836,23 +836,24 @@ function HloSharding(sharding::DimsSharding, size_x)
end

function Base.convert(::Type{HloSharding}, sharding::NamedSharding)
MLIR.IR.with_context(; allow_use_existing=true) do ctx
mesh_op = Reactant.Ops.mesh(
sharding.mesh; mod=MLIR.IR.Module(MLIR.IR.Location(; context=ctx))
)

tensor_sharding_attr, _ = get_tensor_sharding_attribute(
sharding, ctx, mesh_op.sym_name, mesh_op.mesh_attr, nothing; dialect=:sdy
)
MLIR.IR.@dispose ctx = Reactant.ReactantContext() mod = MLIR.IR.Module(
MLIR.IR.Location(; context=ctx)
) begin
MLIR.IR.@scope ctx mod begin
mesh_op = Reactant.Ops.mesh(sharding.mesh; mod=MLIR.IR.Module())
tensor_sharding_attr, _ = get_tensor_sharding_attribute(
sharding, ctx, mesh_op.sym_name, mesh_op.mesh_attr, nothing; dialect=:sdy
)

return HloSharding(
hlo_sharding_from_sdy_tensor_sharding_attr(
tensor_sharding_attr, mesh_op.mesh_attr
),
sharding.mesh,
sharding.is_closed,
sharding.priority,
)
return HloSharding(
hlo_sharding_from_sdy_tensor_sharding_attr(
tensor_sharding_attr, mesh_op.mesh_attr
),
sharding.mesh,
sharding.is_closed,
sharding.priority,
)
end
end
end

Expand Down
25 changes: 12 additions & 13 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,7 @@ function make_mlir_fn(
Ops.activate_constant_context!(fnbody)
@assert MLIR.IR.has_block()

# Explicitly don't use with_block to avoid creating a closure, which creates
# both compile-time and relocatability issues
MLIR.IR.activate(fnbody)

result = try
process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)

Expand Down Expand Up @@ -538,8 +535,8 @@ function prepare_mlir_fn_args(
end
end

func = MLIR.IR.with_block(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
func = MLIR.IR.@scope MLIR.IR.body(mod) begin
MLIR.Dialects.func.func_(;
sym_name=name * "_tmp",
function_type=MLIR.IR.FunctionType(in_tys, Vector{MLIR.IR.Type}(undef, 0)),
body=MLIR.IR.Region(),
Expand Down Expand Up @@ -877,8 +874,8 @@ function finalize_mlir_fn(
MLIR.IR.deactivate(fnbody)
end

func2 = MLIR.IR.with_block(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
func2 = MLIR.IR.@scope MLIR.IR.body(mod) begin
MLIR.Dialects.func.func_(;
sym_name=__lookup_unique_name_in_module(mod, name),
function_type=MLIR.IR.FunctionType(in_tys, out_tys),
body=MLIR.IR.Region(),
Expand Down Expand Up @@ -1033,13 +1030,15 @@ end

function __lookup_unique_name_in_module(mod, name)
new_name = name
tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
for i in 0:10000
new_name = i == 0 ? name : name * "_" * string(i)
MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, new_name)) && return new_name
MLIR.IR.@dispose tab = MLIR.IR.SymbolTable(mod) begin
for i in 0:10000
new_name = i == 0 ? name : name * "_" * string(i)
MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, new_name)) &&
return new_name
end
modstr = string(mod)
return error("Mod\n$modstr\nCould not find unique name for $name")
end
modstr = string(mod)
return error("Mod\n$modstr\nCould not find unique name for $name")
end

function __take_region(compiled_fn)
Expand Down
9 changes: 0 additions & 9 deletions src/mlir/IR/Block.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,3 @@ function current_block(; throw_error::Core.Bool=true)
end
return last(task_local_storage(:mlir_block)::Vector{Block})
end

function with_block(f, blk::Block)
activate(blk)
try
f()
finally
deactivate(blk)
end
end
29 changes: 0 additions & 29 deletions src/mlir/IR/Context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,3 @@ function current_context(; throw_error::Core.Bool=true)
end
return last(task_local_storage(:mlir_context_stack)::Vector{Context})
end

function with_context(f, ctx::Context)
activate(ctx)
try
f()
finally
deactivate(ctx)
end
end

# TODO replace this method on all call sites for the one accepting a context argument
function with_context(f; allow_use_existing=false)
do_dispose = false
if allow_use_existing && has_context()
ctx = current_context()
else
ctx = Context(Reactant.registry[])
do_dispose = true
@ccall API.mlir_c.RegisterDialects(ctx::API.MlirContext)::Cvoid
end

activate(ctx)
try
return f(ctx)
finally
deactivate(ctx)
do_dispose && dispose(ctx)
end
end
6 changes: 3 additions & 3 deletions src/mlir/IR/IR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ end
# MLIR `Type` and `Module`
export Attribute, Block, Context, Dialect, Location, Operation, Region, Value
export activate, deactivate, dispose, enable_multithreading!
export context, current_context, has_context, with_context
export block, current_block, has_block, with_block
export current_module, has_module, with_module
export context, current_context, has_context
export block, current_block, has_block
export current_module, has_module
export type, settype!, location, typeid, dialect
export nattrs, getattr, setattr!, rmattr!
export nregions, region
Expand Down
9 changes: 0 additions & 9 deletions src/mlir/IR/Module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,3 @@ function current_module(; throw_error::Core.Bool=true)
end
return last(task_local_storage(:mlir_module)::Vector{Module})
end

function with_module(f, blk::Module)
activate(blk)
try
f()
finally
deactivate(blk)
end
end
29 changes: 12 additions & 17 deletions src/mlir/IR/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,21 @@ end

Activates `obj` for the duration of `body`, then deactivates it.
"""
macro scope(obj, body)
bodybody = if Base.isexpr(body, :block)
body.args
else
[body]
end
if Base.isexpr(obj, :(=))
prologue = esc(obj)
symbol = obj.args[1]
else
prologue = nothing
symbol = esc(obj)
end
macro scope(args...)
@assert length(args) >= 2

objs = args[1:(end - 1)]
body = last(args)

activations = [:($activate($(esc(obj)))) for obj in objs]
deactivations = [:($deactivate($(esc(obj)))) for obj in reverse(objs)]

quote
$prologue
activate($symbol)
$(activations...)
try
$(esc.(bodybody)...)
$(esc(body))
finally
deactivate($symbol)
$(deactivations...)
end
end
end
Loading
Loading