Skip to content

Autodifferentiation with FFT and Enzyme? #597

Open
@jtravs

Description

@jtravs

This question follows a discussion on Discourse, and @gdalle suggested I post an MWE here.

I'm trying to solve a nonlinear least squares system with NonlinearSolve, where the objective function relies on FFT computation (along with mutation and buffer storage for efficiency).

Currently, the example below works quite well with AutoFiniteDiff. I wanted to get it to work with a true autodiff backend for efficiency / accuracy.

ForwardDiff is possible, but needs substantial code modification (below is a very simplified extract) to have generic types and generic buffers. So instead, I was trying Enzyme.

With the help of @danielwe on Discourse we got Enzyme working with the FFT (his code appears at the top of the example below), but now I get a harder-looking problem.

Here is my (M)WE:

using Test, Random, FFTW, LinearAlgebra
using NonlinearSolve
using AbstractFFTs, Enzyme
import Enzyme: EnzymeCore


# Declare every subtype of `AbstractFFTs.Plan` to be an inactive type for Enzyme.
function EnzymeRules.inactive_type(::Type{<:AbstractFFTs.Plan})
    return true
end

# Declare calls to `plan_fft` (with any arguments) to be inactive for Enzyme.
function EnzymeRules.inactive(::typeof(plan_fft), args...)
    return true
end

# Augmented forward pass of a custom Enzyme reverse-mode rule for `P * x`, where `P` is a
# `Const` FFT plan and `x` is a `Duplicated` strided array whose element type `T` matches
# the plan's. Only configurations of width 1 (`RevConfigWidth{1}`) are handled.
#
# Returns an `EnzymeRules.AugmentedReturn(primal, shadow, tape)` where
#   - `primal` is the product `P.val * x.val` if the config requests it, else `nothing`;
#   - `shadow` is a zeroed array shaped like the product if the config requests it,
#     else `nothing`;
#   - `tape` is the tuple `(inplace, shadow)` consumed by the matching `reverse` method.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    # we can never skip the forward pass because we don't know a priori whether P is an
    # in-place plan, and in-place mutation for non-NoNeed arguments must be performed
    # regardless of needs_primal
    xval = x.val
    yval = P.val * xval
    # `Base.mightalias` is used as the in-place detector: if the output may share memory
    # with the input, the only case handled is the output being the very same array.
    inplace = Base.mightalias(yval, xval)
    if inplace
        @assert yval === xval  # otherwise I don't know what to do
    end
    needs_primal = EnzymeRules.needs_primal(config)
    primal = needs_primal ? yval : nothing
    # The shadow must not be the same array as a returned primal, nor (for in-place plans)
    # the same array as `x.val`; in those cases `make_zero(yval)` is used to obtain a
    # separate zeroed array (NOTE(review): assumes `make_zero` allocates — confirm).
    shadow = if EnzymeRules.needs_shadow(config)
        if needs_primal || inplace
            make_zero(yval)
        else  # might as well reuse yval as shadow then
            make_zero!(yval)
            yval
        end
    else
        nothing
    end
    tape = (inplace, shadow)  # since * is linear, we don't care whether x is overwritten
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

# Reverse pass of the custom Enzyme rule for `P * x` (constant FFT plan `P`, `Duplicated`
# strided array `x`, width-1 configurations only).
#
# `tape` is the `(inplace, shadow)` tuple built by the matching `augmented_primal`.
# The gradient with respect to `x` is written into `x.dval` in place; the returned tuple
# holds `nothing` for each of the two arguments (`P` and `x`).
function EnzymeRules.reverse(
    ::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(*)},
    ::Type,
    tape,
    P::Const{<:AbstractFFTs.Plan{T}},
    x::Duplicated{<:StridedArray{T}},
) where {T}
    inplace, shadow = tape
    # The pullback of the linear map `x -> P * x` is the adjoint plan.
    Padj = adjoint(P.val)
    dx = x.dval
    if inplace
        # For an in-place plan, `dx` itself is transformed by the adjoint plan; the
        # assertion checks that the adjoint plan wrote its result back into `dx`.
        out = Padj * dx
        @assert out === dx  # sanity check that adjoint(P) is in-place when P is
    end
    if !isnothing(shadow)
        # Accumulate the pulled-back output shadow into `dx`, then reset the shadow to
        # zero so it does not contribute again.
        dx .+= Padj * shadow  # mul! not yet supported for adjoint plans
        make_zero!(shadow)
    end
    return (nothing, nothing)
end


# Overwrite `sig` element-wise with `E * |G|^2` and return `sig`.
function signal!(sig, E, G)
    @. sig = E * abs2(G)
    return sig
end

# Supertype for callable objects that turn a field into a trace.
abstract type AbstractTraceMaker end

# Trace generator holding its grids, its output array and all scratch buffers, so that
# evaluating it reuses the same storage on every call. `TFT` is the type of the stored
# transform object `FT`.
struct SimpleTraceMaker{TFT} <: AbstractTraceMaker
    τ::Vector{Float64}          # delay grid; one trace column per entry
    ω::Vector{Float64}          # frequency grid; one trace row per entry
    trace::Matrix{Float64}      # output array, overwritten on each evaluation
    Ebuf::Vector{ComplexF64}    # scratch: transform of the input field
    Gbuf::Vector{ComplexF64}    # scratch: transform results inside the delay loop
    sigbuf::Vector{ComplexF64}  # scratch: per-delay signal
    FT::TFT                     # transform object used with `mul!` / `ldiv!`
end

# Convenience constructor: allocate the `length(ω) × length(τ)` trace array, three complex
# work buffers of length `length(ω)`, and an FFTW plan created from one of those buffers.
function SimpleTraceMaker(τ, ω)
    nω, nτ = length(ω), length(τ)
    spectrogram = zeros(Float64, nω, nτ)
    fieldbuf = zeros(Complex{Float64}, nω)
    gatebuf = zeros(Complex{Float64}, nω)
    productbuf = zeros(Complex{Float64}, nω)
    plan = FFTW.plan_fft(gatebuf)
    # The result is deliberately discarded. Presumably this builds and caches the inverse
    # plan up front so that the later `ldiv!` calls can reuse it — TODO confirm.
    inv(plan)
    return SimpleTraceMaker(τ, ω, spectrogram, fieldbuf, gatebuf, productbuf, plan)
end

# Compute the trace of the field `Eω` and return it normalized to a maximum of 1.
#
# `Eω` must have the same length as `tm.ω`. All intermediate results are written into the
# buffers owned by `tm`, and the returned array is `tm.trace` itself (not a copy): it is
# overwritten by the next call, so callers that need to keep it must `copy` it.
function (tm::SimpleTraceMaker)(Eω)
    # Transform of the undelayed field, reused for every delay.
    mul!(tm.Ebuf, tm.FT, Eω)
    for (i, τ) in enumerate(tm.τ)
        # Apply the delay τ as a linear spectral phase on the input field.
        tm.sigbuf .= Eω .* exp.(1im .* tm.ω .* τ)
        mul!(tm.Gbuf, tm.FT, tm.sigbuf)
        # Gate the field with the squared magnitude of the delayed copy.
        signal!(tm.sigbuf, tm.Ebuf, tm.Gbuf)
        # Undo the transform and store the squared magnitude as column `i`.
        ldiv!(tm.Gbuf, tm.FT, tm.sigbuf)
        tm.trace[:, i] .= abs2.(tm.Gbuf)
    end
    tm.trace ./= maximum(tm.trace)
    return tm.trace
end


# Residual functor for the least-squares fit: bundles the target trace, scratch storage,
# the trace maker `tm`, and a fixed residual scale factor `s`.
struct Objective{TM<:AbstractTraceMaker}
    trace::Matrix{Float64}      # target trace the model is compared against
    Eωbuf::Vector{ComplexF64}   # scratch: complex field rebuilt from the parameter vector
    diffbuf::Vector{Float64}    # scratch with one entry per trace element
    tm::TM                      # callable that maps a field to a trace
    s::Float64                  # factor applied to every residual entry
end

# Build an `Objective` for the target `trace` and trace maker `tm`. The target is copied,
# the field buffer is sized from `tm.ω`, and the residual scale is
# `1 / (sqrt(length(trace)) * maximum(trace))`.
function Objective(trace, tm)
    nfreq = length(tm.ω)
    npix = length(trace)
    fieldbuf = zeros(Complex{Float64}, nfreq)
    residbuf = zeros(npix)
    scale = 1.0 / (sqrt(npix) * maximum(trace))
    return Objective(copy(trace), fieldbuf, residbuf, tm, scale)
end

# In-place residual `f!(out, u, p)`.
#
# With `N = length(o.Eωbuf)`, `u[1:N]` is read as the field amplitudes and `u[N+1:2N]` as
# its phases. The modelled trace is rescaled by the least-squares optimal factor `μ`
# before being compared with the target; `out` receives the scaled differences in
# linear-index order. `p` is unused. Returns `nothing`.
function (o::Objective)(out, u, p)
    nfreq = length(o.Eωbuf)
    amplitude = @view u[1:nfreq]
    phase = @view u[nfreq+1:2nfreq]
    o.Eωbuf .= amplitude .* exp.(1im .* phase)
    model = o.tm(o.Eωbuf)

    cart = CartesianIndices(model)
    lin = LinearIndices(model)

    # μ minimizes sum((o.trace - μ * model).^2) over the scalar μ.
    overlap = sum(model[idc] * o.trace[idc] for idc in cart)
    selfnorm = sum(model[idc]^2 for idc in cart)
    μ = overlap / selfnorm

    for idc in cart
        out[lin[idc]] = (o.trace[idc] - μ * model[idc]) * o.s
    end
    return nothing
end


# Build the test problem: a Gaussian pulse on a 128-point grid, its spectrum normalized to
# unit peak magnitude, the trace maker over 87 delays, the reference trace, and a seeded
# random initial guess. Returns `(tm, trace, Eω, Eω0)`.
function testfield()
    npts = 128
    dt = 0.3e-15

    # Time grid, shifted so that the first element after `ifftshift` is zero.
    tgrid = collect(1:npts) .* dt
    tgrid .-= ifftshift(tgrid)[1]

    # Frequency grid, shifted the same way.
    dω = 2π / (dt * npts)
    ω = collect(1:npts) .* dω
    ω .-= ifftshift(ω)[1]

    pulse = complex.(exp.(-0.5 .* (tgrid ./ 1.5e-15) .^ 2))
    Eω = fftshift(ifft(ifftshift(pulse)))
    Eω ./= sqrt(maximum(abs2, Eω))

    τ = collect(range(-15e-15, 15e-15; length=87))
    tm = SimpleTraceMaker(τ, ω)
    # `tm` returns its internal buffer, so take a copy to keep as the reference.
    trace = copy(tm(Eω))

    Eω0 = rand(Xoshiro(123), Complex{Float64}, length(Eω))
    return (tm, trace, Eω, Eω0)
end


# Driver: build the test problem, solve it with Levenberg–Marquardt using Enzyme
# reverse-mode autodiff, and check that the retrieved field matches the reference.
tm, trace, Eω, Eω0 = testfield()
o = Objective(trace, tm)
# Parameter vector layout expected by `o`: amplitudes first, then phases.
u0 = vcat(abs.(Eω0), angle.(Eω0))

func = NonlinearFunction(o, resid_prototype = zeros(length(o.trace)))
problem = NonlinearLeastSquaresProblem(func, u0)
res = solve(
    problem,
    LevenbergMarquardt(
        autodiff=AutoEnzyme(;
            mode=EnzymeCore.Reverse, function_annotation=EnzymeCore.Duplicated
        ),
    );
    maxiters=30,
    show_trace=Val(true),
)

ferr = sqrt(sum(res.resid.^2) / (length(trace) * maximum(trace)^2))
Eret = res.u[1:length(Eω)] .* exp.(1im .* res.u[length(Eω)+1:2*length(Eω)])
Eret ./= maximum(abs.(Eret))
Eω ./= maximum(abs.(Eω))

# `atol` is only accepted by `@test` together with an approximate comparison.
@test ferr ≈ 0.0 atol=4e-10
@test isapprox(abs.(Eω), abs.(Eret), rtol=5e-3, atol=1e-10)

The error I get prints a large amount of LLVM IR and ends like this:

  %818 = bitcast [16 x i64] %484 to i64, !dbg !11249
  %_augmented398 = call { i8*, i32, i32 } %817(i64 %485, i64 %818) [ "jl_roots"({ i8*, {} addrspace(10)* } %468, [16 x { i8*, {} addrspace(10)* }] %467) ], !dbg !11249
  %subcache399 = extractvalue { i8*, i32, i32 } %_augmented398, 0, !dbg !11249
  %819 = extractvalue { i8*, i32, i32 } %_augmented398, 1, !dbg !11249
  %820 = getelementptr inbounds i8, i8 addrspace(11)* %13, i64 44, !dbg !11337
  %821 = bitcast i8 addrspace(11)* %820 to i32 addrspace(11)*, !dbg !11337
  %822 = load i32, i32 addrspace(11)* %821, align 4, !dbg !11337, !tbaa !127, !alias.scope !11083, !noalias !11086, !enzyme_type !133, !enzyme_inactive !0, !enzymejl_byref_BITS_VALUE !0, !enzymejl_source_type_Int32 !0
  %823 = icmp eq i32 %819, %822, !dbg !11385
  br i1 %823, label %L108, label %L98, !dbg !11250
}

ERROR: LLVM error: function failed verification (2)
Stacktrace:
 [1] handle_error(reason::Cstring)
   @ LLVM ~/.julia/packages/LLVM/xTJfF/src/core/context.jl:194

I would greatly appreciate any help or tips on this. As this is not really a bug, I don't mind if you decide to just close this issue. I can survive with AutoFiniteDiff. But if this could be made to work it would be excellent!

Metadata

Metadata

Assignees

No one assigned

    Labels

    question — Further information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions