Description
This question follows a discussion on discourse and @gdalle suggested I post a MWE example here.
I'm trying to solve a nonlinear least squares system with NonlinearSolve, where the objective function relies on FFT computation (along with mutation and buffer storage for efficiency).
Currently, the example below works quite well with AutoFiniteDiff. I wanted to get it to work with a true autodiff backend for efficiency / accuracy.
ForwardDiff is possible, but needs substantial code modification (below is a very simplified extract) to have generic types and generic buffers. So instead, I was trying Enzyme.
With the help of @danielwe on discourse we got Enzyme working with the FFT (his code appears at the top of the example below), but now I get a harder looking problem.
Here is my (M)WE:
using Test, Random, FFTW, LinearAlgebra
using NonlinearSolve
using AbstractFFTs, Enzyme
import Enzyme: EnzymeCore
Enzyme.EnzymeRules.inactive_type(::Type{<:AbstractFFTs.Plan}) = true
EnzymeRules.inactive(::typeof(plan_fft), args...) = true
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
::Const{typeof(*)},
::Type,
P::Const{<:AbstractFFTs.Plan{T}},
x::Duplicated{<:StridedArray{T}},
) where {T}
# we can never skip the forward pass because we don't know a priori whether P is an
# in-place plan, and in-place mutation for non-NoNeed arguments must be performed
# regardless of needs_primal
xval = x.val
yval = P.val * xval
inplace = Base.mightalias(yval, xval)
if inplace
@assert yval === xval # otherwise I don't know what to do
end
needs_primal = EnzymeRules.needs_primal(config)
primal = needs_primal ? yval : nothing
shadow = if EnzymeRules.needs_shadow(config)
if needs_primal || inplace
make_zero(yval)
else # might as well reuse yval as shadow then
make_zero!(yval)
yval
end
else
nothing
end
tape = (inplace, shadow) # since * is linear, we don't care whether x is overwritten
return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end
function EnzymeRules.reverse(
::EnzymeRules.RevConfigWidth{1},
::Const{typeof(*)},
::Type,
tape,
P::Const{<:AbstractFFTs.Plan{T}},
x::Duplicated{<:StridedArray{T}},
) where {T}
inplace, shadow = tape
Padj = adjoint(P.val)
dx = x.dval
if inplace
out = Padj * dx
@assert out === dx # sanity check that adjoint(P) is in-place when P is
end
if !isnothing(shadow)
dx .+= Padj * shadow # mul! not yet supported for adjoint plans
make_zero!(shadow)
end
return (nothing, nothing)
end
function signal!(sig, E, G)
sig .= E .* abs2.(G)
end
abstract type AbstractTraceMaker end
struct SimpleTraceMaker{TFT} <: AbstractTraceMaker
τ::Vector{Float64}
ω::Vector{Float64}
trace::Array{Float64,2}
Ebuf::Vector{Complex{Float64}}
Gbuf::Vector{Complex{Float64}}
sigbuf::Vector{Complex{Float64}}
FT::TFT
end
function SimpleTraceMaker(τ, ω)
trace = zeros(Float64, length(ω), length(τ))
Ebuf = zeros(Complex{Float64}, length(ω))
Gbuf = zeros(Complex{Float64}, length(ω))
sigbuf = zeros(Complex{Float64}, length(ω))
FT = FFTW.plan_fft(Gbuf)
inv(FT)
SimpleTraceMaker(τ, ω, trace, Ebuf, Gbuf, sigbuf, FT)
end
function (tm::SimpleTraceMaker)(Eω)
mul!(tm.Ebuf, tm.FT, Eω)
for (i, τ) in enumerate(tm.τ)
tm.sigbuf .= Eω .* exp.(1im .* tm.ω .* τ)
mul!(tm.Gbuf, tm.FT, tm.sigbuf)
signal!(tm.sigbuf, tm.Ebuf, tm.Gbuf)
ldiv!(tm.Gbuf, tm.FT, tm.sigbuf)
tm.trace[:,i] .= abs2.(tm.Gbuf)
end
tm.trace ./= maximum(tm.trace)
tm.trace
end
struct Objective{TM<:AbstractTraceMaker}
trace::Array{Float64,2}
Eωbuf::Vector{Complex{Float64}}
diffbuf::Vector{Float64}
tm::TM
s::Float64
end
function Objective(trace, tm)
Eωbuf = zeros(Complex{Float64}, length(tm.ω))
diffbuf = zeros(length(trace))
s = 1.0 / (sqrt(length(trace)) * maximum(trace))
Objective(copy(trace), Eωbuf, diffbuf, tm, s)
end
function (o::Objective)(out, u, p)
N = length(o.Eωbuf)
@views o.Eωbuf .= u[1:N] .* exp.(1im .* u[N+1:2N])
testtrace = o.tm(o.Eωbuf)
µ = (sum(testtrace[idc] * o.trace[idc] for idc in CartesianIndices(testtrace))
/ sum(testtrace[idc]^2 for idc in CartesianIndices(testtrace)))
for idc in CartesianIndices(testtrace)
out[LinearIndices(testtrace)[idc]] = (o.trace[idc] - µ * testtrace[idc]) * o.s
end
nothing
end
function testfield()
N = 128
dt = 0.3e-15
t = collect(1:N) .* dt
t0 = ifftshift(t)[1]
t .-= t0
dω = 2π/(dt*N)
ω = collect(1:N) .* dω
ω0 = ifftshift(ω)[1]
ω .-= ω0
pt = complex.(exp.(-0.5 .* (t ./ 1.5e-15).^2))
Eω = fftshift(ifft(ifftshift(pt)))
Eω ./= sqrt(maximum(abs2.(Eω)))
τ = collect(range(-15e-15, 15e-15, 87))
tm = SimpleTraceMaker(τ, ω)
trace = copy(tm(Eω))
Eω0 = rand(Xoshiro(123), Complex{Float64}, length(Eω))
tm, trace, Eω, Eω0
end
tm, trace, Eω, Eω0 = testfield()
o = Objective(trace, tm)
u0 = vcat(abs.(Eω0), angle.(Eω0))
func = NonlinearFunction(o, resid_prototype = zeros(length(o.trace)))
problem = NonlinearLeastSquaresProblem(func, u0)
res = solve(problem, LevenbergMarquardt(autodiff=AutoEnzyme(;mode=EnzymeCore.Reverse, function_annotation=EnzymeCore.Duplicated)); maxiters=30, show_trace=Val(true))
ferr = sqrt(sum(res.resid.^2) / (length(trace) * maximum(trace)^2))
Eret = res.u[1:length(Eω)] .* exp.(1im .* res.u[length(Eω)+1:2*length(Eω)])
Eret ./= maximum(abs.(Eret))
Eω ./= maximum(abs.(Eω))
@test ferr ≈ 0.0 atol=4e-10
@test isapprox(abs.(Eω), abs.(Eret), rtol=5e-3, atol=1e-10)
The error I get involves the printing of lots of LLVM-like output, and ends like this
%818 = bitcast [16 x i64] %484 to i64, !dbg !11249
%_augmented398 = call { i8*, i32, i32 } %817(i64 %485, i64 %818) [ "jl_roots"({ i8*, {} addrspace(10)* } %468, [16 x { i8*, {} addrspace(10)* }] %467) ], !dbg !11249
%subcache399 = extractvalue { i8*, i32, i32 } %_augmented398, 0, !dbg !11249
%819 = extractvalue { i8*, i32, i32 } %_augmented398, 1, !dbg !11249
%820 = getelementptr inbounds i8, i8 addrspace(11)* %13, i64 44, !dbg !11337
%821 = bitcast i8 addrspace(11)* %820 to i32 addrspace(11)*, !dbg !11337
%822 = load i32, i32 addrspace(11)* %821, align 4, !dbg !11337, !tbaa !127, !alias.scope !11083, !noalias !11086, !enzyme_type !133, !enzyme_inactive !0, !enzymejl_byref_BITS_VALUE !0, !enzymejl_source_type_Int32 !0
%823 = icmp eq i32 %819, %822, !dbg !11385
br i1 %823, label %L108, label %L98, !dbg !11250
}
ERROR: LLVM error: function failed verification (2)
Stacktrace:
[1] handle_error(reason::Cstring)
@ LLVM ~/.julia/packages/LLVM/xTJfF/src/core/context.jl:194
I would greatly appreciate any help or tips on this. As this is not really a bug, I don't mind if you decide to just close this issue. I can survive with AutoFiniteDiff. But if this could be made to work it would be excellent!