-
Notifications
You must be signed in to change notification settings - Fork 48
Open
Description
Another problem came up in our initial-conditions computation, where we don't care much about avoiding allocations. I haven't found a workaround for this one yet. It seems the array allocated inside the function isn't being traced correctly by `@jit`: the error below shows `mul!` receiving a concrete `ConcretePJRTArray` as its output alongside traced (`TracedRArray`) inputs. MWE below:
using Reactant
using LinearAlgebra
# Draw random Float32 host data and convert each array with `Reactant.to_rarray`.
# Shapes: `A` is 50×100 and `x` is 100×1, so the product `A * x` is 50×1.
A, x = (
    Reactant.to_rarray(rand(Float32, 50, 100)),
    Reactant.to_rarray(rand(Float32, 100, 1)),
)
# Show which array type `to_rarray` produced for each input.
for (label, arr) in (("A type: ", A), ("x type: ", x))
    println(label, typeof(arr))
end
# mul! with output allocation inside @jit
function test_mul(A, x)
output = Reactant.to_rarray(zeros(Float32, size(A, 1), size(x, 2)))
mul!(output, A, x)
return output
end
result = @jit test_mul(A, x)
ERROR: MethodError: no method matching overloaded_mul!(::ConcretePJRTArray{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::Bool, ::Bool)
The function `overloaded_mul!` exists, but no method is defined for this combination of argument types.
Closest candidates are:
overloaded_mul!(::Reactant.TracedRArray{T, 2} where T, ::AbstractMatrix, ::AbstractMatrix, ::Number, ::Number)
@ Reactant ~/.julia/packages/Reactant/itXRw/src/stdlibs/LinearAlgebra.jl:262
overloaded_mul!(::Reactant.TracedRArray{T, 2} where T, ::AbstractMatrix, ::AbstractMatrix, ::Number)
@ Reactant ~/.julia/packages/Reactant/itXRw/src/stdlibs/LinearAlgebra.jl:262
overloaded_mul!(::Reactant.TracedRArray{T, 2}, ::AbstractMatrix, ::AbstractVector, ::Number, ::Number) where T
@ Reactant ~/.julia/packages/Reactant/itXRw/src/stdlibs/LinearAlgebra.jl:251
...
Stacktrace:
[1] call_with_native(::Any, ::Any, ::Vararg{Any}; kwargs...)
@ Reactant ~/.julia/packages/Reactant/itXRw/src/utils.jl:60
[2] #mul!
@ ~/.julia/packages/Reactant/itXRw/src/Overlay.jl:171
[3] call_with_reactant(::typeof(mul!), ::Matrix{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::Bool, ::Bool)
@ Reactant ~/.julia/packages/Reactant/itXRw/src/utils.jl:1090
[4] #mul!
@ ~/.julia/packages/Reactant/itXRw/src/Overlay.jl:188
[5] test_mul
@ ./REPL[7]:3
[6] call_with_reactant(::typeof(test_mul), ::Reactant.TracedRArray{Float32, 2}, ::Reactant.TracedRArray{Float32, 2})
@ Reactant ~/.julia/packages/Reactant/itXRw/src/utils.jl:1090
[7] make_mlir_fn(f::typeof(test_mul), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/itXRw/src/TracedUtils.jl:355
[8] make_mlir_fn
@ ~/.julia/packages/Reactant/itXRw/src/TracedUtils.jl:284 [inlined]
[9] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(test_mul), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/itXRw/src/Compiler.jl:1770
[10] compile_mlir!
@ ~/.julia/packages/Reactant/itXRw/src/Compiler.jl:1732 [inlined]
[11] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/itXRw/src/Compiler.jl:3738
[12] compile_xla
@ ~/.julia/packages/Reactant/itXRw/src/Compiler.jl:3711 [inlined]
[13] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/itXRw/src/Compiler.jl:3826
[14] top-level scope
@ ~/.julia/packages/Reactant/itXRw/src/Compiler.jl:2881
Some type information was truncated. Use `show(err)` to see complete types.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels