Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 59 additions & 40 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ function trace_function_definition(mod, expr)
end
end

cond_val(s) = :(@isdefined($s) ? $s : $MissingTracedValue())

function trace_while(expr; track_numbers, mincut, checkpointing, first_arg=nothing)
Meta.isexpr(expr, :while, 2) || error("expected while expr")
cond, body = expr.args
Expand All @@ -315,7 +317,6 @@ function trace_while(expr; track_numbers, mincut, checkpointing, first_arg=nothi
all_syms = Expr(:tuple, external_syms...)
args_names = Expr(:tuple, external_syms...)

cond_val(s) = :(@isdefined($s) ? $s : nothing)
args_init = Expr(:tuple, (:(Ref($(cond_val(s)))) for s in external_syms)...)

ref_syms = Symbol[Symbol(string(sym), "_ref") for sym in external_syms]
Expand Down Expand Up @@ -476,29 +477,30 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
if depth == 0
error_if_any_control_flow(expr)

counter = 0
expr = MacroTools.prewalk(expr) do x
counter += 1
if x isa Expr && x.head == :if && counter > 1
ex_new, dv, _ = trace_if(x; store_last_line, depth=depth + 1, track_numbers)
append!(discard_vars_from_expansion, dv)
return ex_new
end
return x
end
# counter = 0
# expr = MacroTools.prewalk(expr) do x
# counter += 1
# if x isa Expr && x.head == :if && counter > 1
# ex_new, dv, _ = trace_if(x; store_last_line, depth=depth + 1, track_numbers)
# append!(discard_vars_from_expansion, dv)
# return ex_new
# end
# return x
# end
end

cond_expr = remove_shortcircuiting(expr.args[1])
expr_args = filter(a -> !isa(a, Core.LineNumberNode), expr.args)
cond_expr = remove_shortcircuiting(expr_args[1])
condition_vars = [ExpressionExplorer.compute_symbols_state(cond_expr).references...]

true_block = if store_last_line !== nothing
if expr.args[2] isa Expr
@assert expr.args[2].head == :block "currently we only support blocks"
expr.args[2] = Expr(:block, expr.args[2].args...)
true_last_line = expr.args[2].args[end]
remaining_lines = expr.args[2].args[1:(end - 1)]
if expr_args[2] isa Expr
@assert expr_args[2].head == :block "currently we only support blocks"
expr_args[2] = Expr(:block, expr_args[2].args...)
true_last_line = expr_args[2].args[end]
remaining_lines = expr_args[2].args[1:(end - 1)]
else
true_last_line = expr.args[2]
true_last_line = expr_args[2]
remaining_lines = []
end
quote
Expand All @@ -507,7 +509,7 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
end
else
quote
$(expr.args[2])
$(expr_args[2])
nothing # explicitly return nothing to prevent branches from returning different types
end
end
Expand All @@ -519,13 +521,13 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
all_true_branch_vars = true_branch_input_list ∪ true_branch_assignments
true_branch_fn_name = gensym(:true_branch)

else_block, discard_vars, _ = if length(expr.args) == 3
if !(expr.args[3] isa Expr) || expr.args[3].head != :elseif
expr.args[3], [], nothing
else_block, discard_vars, _ = if length(expr_args) == 3
if !(expr_args[3] isa Expr) || expr_args[3].head != :elseif
expr_args[3], [], nothing
else
trace_if(expr.args[3]; store_last_line, depth=depth + 1, track_numbers)
trace_if(expr_args[3]; store_last_line, depth=depth + 1, track_numbers)
end
elseif length(expr.args) == 2
elseif length(expr_args) == 2
tmp_expr = []
for var in true_branch_assignments
push!(tmp_expr, :($(var) = $(var)))
Expand Down Expand Up @@ -609,22 +611,21 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)

cond_name = gensym(:cond)

args_init = [cond_val(s) for s in all_input_vars]
reactant_code_block = quote
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(cond_name),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),);
($(args_init...),);
track_numbers=($(track_numbers)),
)
end

non_reactant_code_block = Expr(:if, cond_name, original_expr.args[2])
if length(original_expr.args) > 2 # has else block
append!(non_reactant_code_block.args, original_expr.args[3:end])
end
non_reactant_code_block = Expr(:if, original_expr.args...)
non_reactant_code_block.args[1] = cond_name

all_check_vars = [cond_name, all_input_vars..., condition_vars...]
unique!(all_check_vars)
Expand All @@ -640,7 +641,7 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)

return quote
$(cond_name) = $(cond_expr)
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
if $(within_compile)() && $(any)($(is_traced), ($((cond_val.(all_check_vars))...),))
$(reactant_code_block)
else
$(non_reactant_code_block)
Expand Down Expand Up @@ -688,7 +689,7 @@ end

# Generate this dummy function and later we remove it during tracing
function traced_if(cond, true_fn, false_fn, args; track_numbers)
return cond ? true_fn(args) : false_fn(args)
return cond ? true_fn(args...) : false_fn(args...)
end

function traced_while end # defined inside Reactant.jl
Expand All @@ -713,17 +714,35 @@ function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
end
end

const CONTROL_FLOW_EXPRS = [:return, :break, :continue, :symbolicgoto]
const CONTROL_FLOW_EXPRS = Symbol[:return, :break, :continue, :symbolicgoto, :macrocall]

function error_if_any_control_flow(expr)
return MacroTools.postwalk(expr) do x
for head in CONTROL_FLOW_EXPRS
if Meta.isexpr(x, head)
error("Cannot use @trace on a block that contains a $head statement")
end
end
return x
function prewalk_until_function_boundary(f, expr)
if Meta.isexpr(expr, :function) ||
Meta.isexpr(expr, :(->)) ||
(Meta.isexpr(expr, :(=), 2) && Meta.isexpr(expr.args[1], :call))
return expr
end
if expr isa Expr
return f(Expr(expr.head, map(f, expr.args)...))
end
return f(expr)
end

error_if_any_control_flow(_) = nothing
function error_if_any_control_flow(expr::Expr)
if Meta.isexpr(expr, :function) ||
Meta.isexpr(expr, :(->)) ||
(Meta.isexpr(expr, :(=), 2) && Meta.isexpr(expr.args[1], :call))
return expr
end

head_idx = findfirst(==(expr.head), CONTROL_FLOW_EXPRS)
if !isnothing(head_idx)
head = CONTROL_FLOW_EXPRS[head_idx]
error("Cannot use @trace on a block that contains a $head statement")
end

return foreach(error_if_any_control_flow, expr.args)
end

"""
Expand Down
17 changes: 12 additions & 5 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function mlir_type(x::Union{RNumber,RArray})::MLIR.IR.Type
return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x)))
end

mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool))
mlir_type(::MissingTracedValue) = MLIR.IR.TensorType(Int[], MLIR.IR.Type(Bool))

function mlir_type(RT::Type{<:RArray{T,N}}, shape) where {T,N}
@assert length(shape) == N
Expand Down Expand Up @@ -2260,9 +2260,11 @@ end
traced_args = Vector{Any}(undef, N)

for (i, prev) in enumerate(args)
@inbounds traced_args[i] = Reactant.make_tracer(
seen_args, prev, (), Reactant.NoStopTracedTrack; track_numbers
)
@inbounds traced_args[i] = if prev isa Ref && prev[] isa MissingTracedValue
Ref{Nothing}(nothing)
else
Reactant.make_tracer(seen_args, prev, (), Reactant.NoStopTracedTrack; track_numbers)
end
end

linear_args = Reactant.TracedType[]
Expand Down Expand Up @@ -2408,6 +2410,9 @@ end
if isnothing(path)
error("if_condition: could not find path for linear arg $i")
end
if arg isa MissingTracedValue
continue
end
Reactant.TracedUtils.set_mlir_data!(
arg,
only(
Expand Down Expand Up @@ -2724,13 +2729,15 @@ end

corrected_traced_results =
map(zip(traced_false_results, traced_true_results)) do (fr, tr)
if fr isa MissingTracedValue && tr isa MissingTracedValue
res = if fr isa MissingTracedValue && tr isa MissingTracedValue
return fr
elseif fr isa MissingTracedValue
return tr
else
return fr
end
# @something res MissingTracedValue()
res
end

@assert length(all_paths) == length(result_types)
Expand Down
4 changes: 3 additions & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,9 @@ function set!(x, path, tostore; emptypath=false)
x = Reactant.Compiler.traced_getfield(x, p)
end

set_mlir_data!(x, tostore)
if is_traced(x)
set_mlir_data!(x, tostore)
end

return emptypath && set_paths!(x, ())
end
Expand Down
44 changes: 41 additions & 3 deletions test/core/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function condition2_nested_if(x, y)
x_sum = sum(x)
@trace if x_sum > 0
y_sum = sum(y)
if y_sum > 0
@trace if y_sum > 0
z = x_sum + y_sum
else
z = x_sum - y_sum
Expand Down Expand Up @@ -387,8 +387,8 @@ function condition11_nested_ifff(x, y, z)
x_sum = sum(x)
@trace if x_sum > 0
y_sum = sum(y)
if y_sum > 0
if sum(z) > 0
@trace if y_sum > 0
@trace if sum(z) > 0
z = x_sum + y_sum + sum(z)
else
z = x_sum + y_sum
Expand Down Expand Up @@ -1110,3 +1110,41 @@ end
@test a_ra ≈ a
@test b_ra ≈ b
end

function myfunc_traced_if_in_for(x) # compute sum of positive elements in x
s = zero(eltype(x))
@trace for i in eachindex(x)
@allowscalar cond = x[i] > 0
@trace if cond
@allowscalar s += x[i]
end
end
return s
end

function nested_trace_if_for(u, mask)
acc = zero(eltype(u))
n = length(u)
keep = sum(mask) > 0

@trace if keep
@trace for i in 1:n
acc = acc + @allowscalar(u[i])
end
out = acc
else
out = acc
end

return out
end

@testset "for in if and if in for" begin
u = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float64, 16))
mask = Reactant.ConcreteRArray(collect(rand(Float64, 16) .> 0.5))

@test @jit(nested_trace_if_for(u, mask)) ≈ nested_trace_if_for(Array(u), Array(mask))

x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float64, 10))
@test @jit(myfunc_traced_if_in_for(x)) ≈ myfunc_traced_if_in_for(Array(x))
end
Loading